import numpy as np
import json, os

from typing import List, Tuple
from scipy.optimize import linprog
from math import radians, sin, cos, acos

from structs.landmarks import Landmark

    
# Print the suggested tour and a short summary
def print_res(L: List[Landmark], L_tot) -> None:
    """Print the suggested visiting order, the minutes walked and the visit count.

    Args:
        L: landmarks of the tour in visiting order. Every element whose
            ``name`` is not ``'start'`` must carry a ``time_to_reach`` value;
            those values are summed into the "Minutes walked" total.
        L_tot: number of candidate landmarks. When it equals ``len(L)`` the
            header announces that every landmark could be visited.

    Returns:
        None (the original ``-> list`` annotation was wrong; nothing is returned).
    """
    if len(L) == L_tot:
        print('\nAll landmarks can be visited within max_steps, the following order is suggested : ')
    else:
        print('Could not visit all the landmarks, the following order is suggested : ')

    minutes = 0
    for elem in L:
        if elem.name == 'start':
            # The start has no predecessor, hence no time to reach it.
            print(f"- {elem.name}")
        else:
            print(f"- {elem.name}, time to reach = {elem.time_to_reach}")
            minutes += elem.time_to_reach

    print(f"\nMinutes walked : {minutes}")
    print(f"Visited {len(L)} out of {L_tot} landmarks")


# Add a cut that forbids the current set of departure vertices from all being used again
def prevent_config(resx, A_ub, b_ub) -> Tuple[np.ndarray, list]:
    """Append an inequality row that cuts off the current selection of departure vertices.

    ``resx`` is read as a flattened L x L edge matrix (L = int(sqrt(len(resx))));
    a non-zero entry at flat index ``a*L + b`` selects the edge ``(a, b)``.
    Every row index ``a`` of a selected edge, except one occurrence of vertex 0,
    is collected. The new row has ones on all entries of those rows, and its
    right-hand side is (number of collected vertices - 1).

    Args:
        resx: flat solver output; rounded to 0/1 *in place* (kept from the
            original implementation, callers may observe it).
        A_ub: current inequality matrix (or flat first row); a row is stacked onto it.
        b_ub: current right-hand sides; appended to in place.

    Returns:
        ``(A_ub, b_ub)`` with one extra row / entry. (The previous ``-> bool``
        annotation did not match what is returned.)
    """
    for i, elem in enumerate(resx):
        resx[i] = round(elem)

    N = len(resx)               # Number of edge variables
    L = int(np.sqrt(N))         # Number of landmarks (assumes N == L*L)

    # np.nonzero returns one index array per dimension; resx is 1-D so take [0].
    selected = np.nonzero(resx)[0]
    departures = np.unravel_index(selected, (L, L))[0]

    vertices_visited = [int(v) for v in departures]
    # Vertex 0 is excluded from the cut. Previously a missing 0 raised ValueError.
    if 0 in vertices_visited:
        vertices_visited.remove(0)

    h = [0] * N
    for v in set(vertices_visited):
        h[v * L:v * L + L] = [1] * L

    A_ub = np.vstack((A_ub, h))
    b_ub.append(len(vertices_visited) - 1)

    return A_ub, b_ub


# Add a cut that forbids the given vertices (minus the finish) from all being left again
def break_cricle(circle_vertices: list, L: int, A_ub: list, b_ub: list) -> Tuple[np.ndarray, list]:
    """Append an inequality row over the outgoing edges of ``circle_vertices``.

    One occurrence of the finish vertex ``L-1`` is dropped from the vertex
    list. The new row has ones on every entry of each remaining vertex's row
    in the flattened L x L edge matrix. Its right-hand side is
    (number of remaining vertices - 1).

    Args:
        circle_vertices: vertex indices (e.g. the walk returned by
            ``is_connected``). No longer modified: the original removed
            ``L-1`` from the caller's list in place.
        L: number of landmarks.
        A_ub: current inequality matrix (or flat first row); a row is stacked onto it.
        b_ub: current right-hand sides; appended to in place.

    Returns:
        ``(A_ub, b_ub)`` with one extra row / entry. (The previous ``-> bool``
        annotation did not match what is returned.)
    """
    vertices = list(circle_vertices)    # work on a copy, do not touch the caller's list
    if L - 1 in vertices:
        vertices.remove(L - 1)

    h = [0] * L * L
    for i in range(L):
        if i in vertices:
            h[i * L:i * L + L] = [1] * L

    A_ub = np.vstack((A_ub, h))
    b_ub.append(len(vertices) - 1)

    return A_ub, b_ub


# Check whether the selected edges form one connected walk; signal leftover edges otherwise
def is_connected(resx) -> Tuple[list, list]:
    """Walk the edges selected in ``resx`` and check that the walk uses all of them.

    ``resx`` is rounded to 0/1 *in place* and read as a flattened L x L
    matrix (L = int(sqrt(len(resx)))); a non-zero entry at flat index
    ``a*L + b`` is decoded as the edge ``(a, b)``. The walk starts from the
    non-zero entry with the smallest flat index. It repeatedly follows an
    edge whose first vertex equals the current edge's second vertex. It
    stops when no edges remain, or when a candidate edge is examined while
    the current end vertex is ``L-1`` or already visited.

    Args:
        resx: flat solver output (``res.x`` in this module). Presumably a
            float numpy array, since ``.sum()`` is called on it. Modified in place.

    Returns:
        ``(vertices_visited, circle)``. ``vertices_visited`` lists the
        vertices in walk order: it starts with the first vertex of the first
        edge and ends with the last vertex reached. ``circle`` is ``[]`` when
        ``len(vertices_visited) == n_edges + 1``, i.e. the walk covered every
        selected edge. Otherwise it is the list of edges collected during the
        walk, and a non-empty value tells the caller that some selected edges
        lie off the walk.

    NOTE(review): an all-zero ``resx`` makes ``ind_a[0]`` raise IndexError.
    NOTE(review): suppose the current end vertex has no outgoing edge left in
    ``remaining`` and is neither ``L-1`` nor already visited. Then neither
    ``break`` is reached and the ``while`` loop never terminates. This is
    presumably ruled out by the flow constraints built elsewhere; confirm.
    """
    
    # first round the results to have only 0-1 values (mutates the caller's array)
    for i, elem in enumerate(resx):
        resx[i] = round(elem)
    
    N = len(resx)               # length of res
    L = int(np.sqrt(N))         # number of landmarks; assumes N is a perfect square (N == L*L)
    n_edges = resx.sum()        # number of selected edges (entries are 0/1 after rounding)

    nonzeroind = np.nonzero(resx)[0] # np.nonzero returns one index array per dimension; resx is 1-D

    nonzero_tup = np.unravel_index(nonzeroind, (L,L))   # (row indices, column indices) of selected entries

    ind_a = nonzero_tup[0].tolist()     # row index = first vertex of each edge
    ind_b = nonzero_tup[1].tolist()     # column index = second vertex of each edge

    edges = []
    edges_visited = []
    vertices_visited = []

    # The walk starts from the selected entry with the smallest flat index.
    edge1 = (ind_a[0], ind_b[0])
    edges_visited.append(edge1)
    vertices_visited.append(edge1[0])

    for i, a in enumerate(ind_a) :
        edges.append((a, ind_b[i]))      # Create the list of edges

    remaining = edges                   # alias, not a copy: removals below also shrink `edges`
    remaining.remove(edge1)

    break_flag = False
    while len(remaining) > 0 and not break_flag:
        for edge2 in remaining :
            if edge2[0] == edge1[1] :
                # edge2 continues the walk from the current end vertex
                if edge1[1] in vertices_visited :
                    edges_visited.append(edge2)
                    break_flag = True
                    break
                else :    
                    vertices_visited.append(edge1[1])
                    edges_visited.append(edge2)
                    # Removing from the list being iterated skips the next
                    # candidate in this pass; the outer while loop rescans.
                    remaining.remove(edge2)
                    edge1 = edge2

            elif edge1[1] == L-1 or edge1[1] in vertices_visited:
                        # reached vertex L-1, or came back to a visited vertex
                        break_flag = True
                        break

    vertices_visited.append(edge1[1])   # record the final end vertex of the walk


    # A walk over k edges touches k+1 vertices; anything else means edges were left over.
    if len(vertices_visited) == n_edges +1 :
        return vertices_visited, []
    else: 
        return vertices_visited, edges_visited


# Walking distance (km) and walking time (minutes) between two locations
def get_distance(p1: Tuple[float, float], p2: Tuple[float, float], detour: float, speed: float) :
    """Estimate walking distance and time between two (latitude, longitude) points in degrees.

    The straight-line distance is the great-circle distance from the
    spherical law of cosines, on a sphere of radius 6371.01 km. It is then
    multiplied by ``detour``.

    Args:
        p1, p2: (latitude, longitude) in degrees.
        detour: multiplicative factor applied to the straight-line distance.
        speed: walking speed. ``wdist / speed * 60`` yields minutes when
            ``speed`` is in km/h, which is presumably the unit used by the
            caller's parameter file; confirm.

    Returns:
        ``(distance_km, minutes)``: the detoured distance rounded to 0.1 km,
        and the walking time rounded to the nearest 5 minutes when above 15,
        otherwise to the nearest minute. ``(0, 0)`` when ``p1 == p2``.
    """
    if p1 == p2:
        return 0, 0

    lat1, lon1 = radians(p1[0]), radians(p1[1])
    lat2, lon2 = radians(p2[0]), radians(p2[1])
    cos_angle = sin(lat1)*sin(lat2) + cos(lat1)*cos(lat2)*cos(lon1 - lon2)

    # Floating-point rounding can push the cosine slightly outside [-1, 1]
    # (nearly identical or antipodal points), where acos raises ValueError.
    # Explicit comparisons (rather than min/max) let a NaN propagate unchanged.
    if cos_angle > 1.0:
        cos_angle = 1.0
    elif cos_angle < -1.0:
        cos_angle = -1.0

    dist = 6371.01 * acos(cos_angle)    # straight-line distance in km

    # Consider the detour factor for average city
    wdist = dist*detour

    # Time to walk this distance (in minutes)
    wtime = wdist/speed*60

    if wtime > 15:
        wtime = 5*round(wtime/5)
    else:
        wtime = round(wtime)

    return round(wdist, 1), wtime


# Initialize c and the walking-time row of A_ub.
# We want to maximize the sightseeing :  max(c) st. A*x < b   and   A_eq*x = b_eq
def init_ub_dist(landmarks: List[Landmark], max_steps: int):
    """Build the objective vector and the walking-time inequality.

    The decision vector has one entry per ordered pair of landmarks: flat
    index ``i*L + j`` is the edge from landmark ``i`` to landmark ``j``.

    The parameters are read from ``parameters/optimizer.params`` next to
    this module, a JSON object with keys ``'detour factor'`` and
    ``'average walking speed'``. It is read as UTF-8, as JSON requires;
    the original used the platform's default encoding.

    Args:
        landmarks: candidates; ``attractiveness`` and ``location`` are read.
        max_steps: walking-time budget, used as the right-hand side.

    Returns:
        ``(c, A_ub, b_ub)`` where
        ``c[i*L + j] == -landmarks[j].attractiveness`` (negated because
        ``linprog`` minimises), ``A_ub[i*L + j]`` is the walking time in
        minutes from ``i`` to ``j`` as returned by ``get_distance``, and
        ``b_ub == [max_steps]``.
    """
    params_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'parameters', 'optimizer.params')
    with open(params_path, "r", encoding="utf-8") as f:
        parameters = json.load(f)
    detour = parameters['detour factor']
    speed = parameters['average walking speed']

    # Objective coefficients: the reward of an edge is the attractiveness of its destination.
    c = [-spot.attractiveness for spot in landmarks] * len(landmarks)

    # Single inequality row: total walking time of the selected edges.
    A_ub = [
        get_distance(spot1.location, spot2.location, detour, speed)[1]
        for spot1 in landmarks
        for spot2 in landmarks
    ]

    return c, A_ub, [max_steps]


# Constraint to respect max number of travels: every landmark is left at most once
def respect_number(L, A_ub, b_ub):
    """Append one ``<= 1`` row per landmark over its outgoing edges.

    Row ``i`` of the flattened L x L edge matrix occupies flat indices
    ``i*L`` .. ``i*L + L - 1``. For each ``i`` a selector with ones on exactly
    those entries is stacked under ``A_ub``, and ``1`` is appended to
    ``b_ub`` (in place).

    Returns:
        ``(A_ub, b_ub)``; both are returned untouched when ``L`` yields no rows.
    """
    selectors = [[0] * (L * i) + [1] * L + [0] * (L * (L - 1 - i)) for i in range(L)]
    if not selectors:
        return A_ub, b_ub

    A_ub = np.vstack([A_ub, *selectors])
    b_ub.extend([1] * len(selectors))
    return A_ub, b_ub


# Constraint to not have d14 and d41 simultaneously. Does not prevent circular symmetry with more elements
def break_sym(L, A_ub, b_ub):
    """Forbid using both directions of any pair of landmarks.

    For every pair ``i < j``, in row-major order, a row with ones at flat
    indices ``i*L + j`` and ``j*L + i`` is stacked under ``A_ub`` and ``1``
    is appended to ``b_ub`` (in place), i.e. ``x_ij + x_ji <= 1``. Cycles
    through three or more landmarks are not prevented.

    Returns:
        ``(A_ub, b_ub)``; both are returned untouched when there is no pair.
    """
    pair_rows = []
    for i in range(L):
        for j in range(i + 1, L):
            row = [0] * (L * L)
            row[i * L + j] = 1
            row[j * L + i] = 1
            pair_rows.append(row)

    if not pair_rows:
        return A_ub, b_ub

    A_ub = np.vstack([A_ub, *pair_rows])
    b_ub.extend([1] * len(pair_rows))
    return A_ub, b_ub


# Constraint to not stay in position. Removes d11, d22, d33, etc.
def init_eq_not_stay(L: int):
    """Start the equality system with a single row that forbids self-loops.

    The row has ones at the diagonal positions ``i*L + i`` of the flattened
    L x L edge matrix; with right-hand side 0 it forces every ``x_ii`` to 0.

    Returns:
        ``([row], [0])`` where ``row`` is a numpy array of length ``L*L``.
    """
    diagonal = [0] * (L * L)
    for i in range(L):
        diagonal[i * (L + 1)] = 1   # i*L + i
    return [np.array(diagonal)], [0]


# Force the optimizer to visit every landmark the user flagged as must_do
def respect_user_mustsee(landmarks: List[Landmark], A_eq: list, b_eq: list) :
    """Add one equality row per must-do landmark so that it ends up on the tour.

    Consider every landmark ``i`` whose ``must_do`` is ``True`` and whose name
    is neither ``'start'`` nor ``'finish'``. For each one, a row is stacked
    under ``A_eq`` with ones on all edges leaving ``i`` (flat indices
    ``i*L + j``) and all edges entering ``i`` (flat indices ``j*L + i``), and
    ``2`` is appended to ``b_eq``: in(i) + out(i) == 2. Combined with the
    flow-conservation rows added by ``respect_order``, landmark ``i`` is then
    entered and left exactly once. The shared diagonal entry ``(i, i)`` is a
    self-loop, which ``init_eq_not_stay`` forces to 0.

    Args:
        landmarks: candidates; only ``name`` and ``must_do`` are read.
        A_eq: current equality matrix; rows are stacked onto it.
        b_eq: current right-hand sides; appended to in place.

    Returns:
        ``(A_eq, b_eq)`` with one extra row / entry per must-do landmark.
    """
    L = len(landmarks)

    for i, elem in enumerate(landmarks):
        if elem.must_do is True and elem.name not in ['finish', 'start']:
            l = [0] * L * L
            for j in range(L):
                l[i * L + j] = 1    # row i: edges leaving i (go from)
                # Column i: edges entering i (go to). Previously column L-1
                # (edges into the finish) was used. That forbade i from
                # being the last stop before the finish.
                l[j * L + i] = 1

            A_eq = np.vstack((A_eq, l))
            b_eq.append(2)

    return A_eq, b_eq


# Constraint to ensure start at start and finish at goal
def respect_start_finish(L: int, A_eq: list, b_eq: list):
    """Pin the tour to begin at landmark 0 and end at landmark L-1.

    Four rows are stacked under ``A_eq`` (in this order) and their
    right-hand sides are appended to ``b_eq`` in place:

    1. ones on row 0 -> exactly one edge leaves the start (= 1)
    2. a one at ``(0, L-1)`` -> no direct start -> finish edge (= 0)
    3. ones at ``(k, L-1)`` for ``1 <= k <= L-2`` -> one edge enters the finish (= 1)
    4. ones on row ``L-1`` and at ``(k, 0)`` for ``0 <= k <= L-2``
       -> nothing leaves the finish and nothing enters the start (= 0)

    Returns:
        ``(A_eq, b_eq)``.
    """
    size = L * L

    leave_start = [1] * L + [0] * (size - L)

    no_jump = [0] * size
    no_jump[L - 1] = 1

    enter_finish = [0] * size
    for k in range(1, L - 1):
        enter_finish[k * L + L - 1] = 1

    closed_ends = [0] * (size - L) + [1] * L
    for k in range(L - 1):
        closed_ends[k * L] = 1

    A_eq = np.vstack((A_eq, leave_start, no_jump, enter_finish, closed_ends))
    b_eq.extend([1, 0, 1, 0])
    return A_eq, b_eq


# Constraint to tie the problem together. Necessary but not sufficient to avoid circles
def respect_order(N: int, A_eq, b_eq):
    """Append flow-conservation rows for every intermediate landmark.

    For each ``i`` in ``1 .. N-2`` (the start ``0`` and finish ``N-1`` are
    skipped), the row has ``-1`` on column ``i`` (edges entering ``i``) and
    ``+1`` on row ``i`` (edges leaving ``i``). The shared diagonal entry
    ends up ``+1``. The right-hand side is 0, so out(i) == in(i) whenever
    ``x_ii == 0``.

    Returns:
        ``(A_eq, b_eq)``; ``b_eq`` is appended to in place.
    """
    for i in range(1, N - 1):
        row = [0] * (N * N)
        for k in range(N):
            row[k * N + i] = -1             # column i: edges entering i
        row[i * N:(i + 1) * N] = [1] * N    # row i: edges leaving i (overwrites the diagonal)

        A_eq = np.vstack((A_eq, row))
        b_eq.append(0)

    return A_eq, b_eq


# Build the tour as a list of landmarks and annotate each with its walking time
def add_time_to_reach(order: List[int], landmarks: List[Landmark]) -> List[Landmark]:
    """Return ``landmarks`` in tour order, setting ``time_to_reach`` along the way.

    The walking parameters are read from ``parameters/optimizer.params`` next
    to this module, a JSON object with keys ``'detour factor'`` and
    ``'average walking speed'``. It is read as UTF-8, as JSON requires; the
    original used the platform's default encoding.

    Args:
        order: indices into ``landmarks`` in visiting order (the vertex list
            produced by ``is_connected``). Previously mis-annotated as
            ``List[Landmark]``.
        landmarks: all candidates; the walk is measured from ``landmarks[0]``.

    Returns:
        ``[landmarks[k] for k in order]``. Each element that compares unequal
        to its predecessor gets its ``time_to_reach`` attribute set, in place,
        to the walking minutes from ``get_distance``. The first element is
        compared with ``landmarks[0]``.
    """
    params_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'parameters', 'optimizer.params')
    with open(params_path, "r", encoding="utf-8") as f:
        parameters = json.load(f)
    detour = parameters['detour factor']
    speed = parameters['average walking speed']

    tour = []
    prev = landmarks[0]
    for idx in order:
        elem = landmarks[idx]
        if elem != prev:
            elem.time_to_reach = get_distance(elem.location, prev.location, detour, speed)[1]
        tour.append(elem)
        prev = elem

    return tour


# Run the HiGHS solve shared by the initial solve and every re-solve
def _solve_lp(c, A_ub, b_ub, A_eq, b_eq, x_bounds):
    # integrality=3 marks every variable semi-integer; with bounds (0, 1)
    # each variable can therefore only be 0 or 1.
    return linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=x_bounds, method='highs', integrality=3)


# Main optimization pipeline
def solve_optimization (landmarks :List[Landmark], max_steps: int, printing_details: bool) :
    """Choose and order landmarks to maximise attractiveness within a walking budget.

    Builds the MILP with the constraint helpers of this module and solves it.
    ``is_connected`` may report that the selected edges do not form a single
    walk. While that is the case, cuts are added (``prevent_config``,
    ``break_cricle``) and the problem is re-solved, at most ``timeout`` times.

    Args:
        landmarks: candidates. The constraint builders treat index 0 as the
            start and the last index as the finish.
        max_steps: walking-time budget, compared with the minute values from
            ``get_distance``.
        printing_details: when True, print the tour, the score and how many
            re-solves were needed.

    Returns:
        The landmarks of the tour in visiting order, with ``time_to_reach``
        set (see ``add_time_to_reach``). Previously, when the first solution
        was already connected, an int was returned instead.

    Raises:
        ArithmeticError: the initial problem or a re-solve has no solution.
        TimeoutError: the walk is still disconnected after ``timeout`` re-solves.
    """
    L = len(landmarks)

    # SET CONSTRAINTS FOR INEQUALITY
    c, A_ub, b_ub = init_ub_dist(landmarks, max_steps)              # Add the distances from each landmark to the other
    A_ub, b_ub = respect_number(L, A_ub, b_ub)                      # Respect max number of visits (no more possible stops than landmarks).
    A_ub, b_ub = break_sym(L, A_ub, b_ub)                           # break the 'zig-zag' symmetry

    # SET CONSTRAINTS FOR EQUALITY
    A_eq, b_eq = init_eq_not_stay(L)                                # Force solution not to stay in same place
    A_eq, b_eq = respect_user_mustsee(landmarks, A_eq, b_eq)        # Force user-defined must_do landmarks (start/finish are skipped)
    A_eq, b_eq = respect_start_finish(L, A_eq, b_eq)                # Force start and finish positions
    A_eq, b_eq = respect_order(L, A_eq, b_eq)                       # Respect order of visit (only works when max_steps is limiting factor)

    # SET BOUNDS FOR DECISION VARIABLE (x can only be 0 or 1)
    x_bounds = [(0, 1)]*L*L

    infeasible_msg = "No solution could be found, the problem is overconstrained. Please adapt your must_dos"

    res = _solve_lp(c, A_ub, b_ub, A_eq, b_eq, x_bounds)
    if not res.success:
        raise ArithmeticError(infeasible_msg)

    # Add cuts and re-solve until the selected edges form one walk.
    order, circle = is_connected(res.x)
    n_resolves = 0
    timeout = 300
    while len(circle) != 0:
        if n_resolves >= timeout:
            raise TimeoutError(f"Optimization took too long. No solution found after {timeout} iterations.")
        A_ub, b_ub = prevent_config(res.x, A_ub, b_ub)
        A_ub, b_ub = break_cricle(order, L, A_ub, b_ub)
        res = _solve_lp(c, A_ub, b_ub, A_eq, b_eq, x_bounds)
        n_resolves += 1
        # Without this check res.x is None and is_connected would crash.
        if not res.success:
            raise ArithmeticError(infeasible_msg)
        order, circle = is_connected(res.x)

    # Always build the tour, including when the first solution was already connected.
    tour = add_time_to_reach(order, landmarks)

    if printing_details:
        if n_resolves != 0:
            print(f"Needed to recompute paths {n_resolves} times because of unconnected loops...")
        print_res(tour, L)
        print("\nTotal score : " + str(int(-res.fun)))

    return tour