import yaml, logging
import numpy as np

from scipy.optimize import linprog
from collections import defaultdict, deque

from ..structs.landmark import Landmark
from .get_time_separation import get_time
from ..constants import OPTIMIZER_PARAMETERS_PATH

    



class Optimizer:
    """
    Select and order landmarks for a walking tour by solving an integer linear program.

    The decision variable is a flattened L x L matrix of 0/1 edge weights, L being the number of
    landmarks: ``x[i*L + j] == 1`` means the tour walks from landmark ``i`` to landmark ``j``.
    The constraints built below make landmark 0 the start and landmark ``L-1`` the finish.
    """

    logger = logging.getLogger(__name__)

    detour: int = None              # accepted max detour time (in minutes)
    detour_factor: float            # detour factor of straight line vs real distance in cities
    average_walking_speed: float    # average walking speed of adult
    max_landmarks: int              # max number of landmarks to visit
    overshoot: float                # overshoot to allow maxtime to overflow. Optimizer is a bit restrictive


    def __init__(self):
        """Load the optimizer parameters from the YAML file at ``OPTIMIZER_PARAMETERS_PATH``."""
        with OPTIMIZER_PARAMETERS_PATH.open('r') as f:
            parameters = yaml.safe_load(f)

        self.detour_factor = parameters['detour_factor']
        self.average_walking_speed = parameters['average_walking_speed']
        self.max_landmarks = parameters['max_landmarks']
        self.overshoot = parameters['overshoot']


    def prevent_config(self, resx):
        """
        Build a constraint forbidding a later solution to reuse the landmark selection of ``resx``.

        ``resx`` is rounded to integers **in place**. The returned row has a 1 on every outgoing
        edge of each landmark the solution departs from (the start, landmark 0, excluded), and its
        upper bound is one less than the number of such departures, so those landmarks cannot all
        be left again.

        Args:
            resx (list[float]): Flattened L*L edge weights of a solution.

        Returns:
            tuple[list[int], list[int]]: A new row for the inequality constraint matrix and its upper bound.

        Raises:
            ValueError: If the solution has no edge leaving landmark 0.
        """
        for i, elem in enumerate(resx):
            resx[i] = round(elem)

        N = len(resx)               # Number of edges
        L = int(np.sqrt(N))         # Number of landmarks

        # Row index (= departure landmark) of every selected edge.
        departures = np.unravel_index(np.nonzero(resx)[0], (L, L))[0].tolist()
        departures.remove(0)        # the start is always left; it is not part of the "configuration"

        h = [0]*N
        for i in set(departures):
            h[i*L:i*L+L] = [1]*L
        return h, [len(departures)-1]


    def prevent_circle(self, circle_vertices: list, L: int):
        """
        Build constraints that forbid a detected circle (sub-tour) in both directions.

        The first row covers the edges of the closed loop ``v0 -> v1 -> ... -> vk -> v0`` and the
        second row the same loop reversed. They are meant to be added as *equality* constraints with
        right-hand side 0, which (the variables being bounded to [0, 1]) forces each of those edges to 0.

        Args:
            circle_vertices (list): Landmark indices of the circle, in traversal order.
            L (int): Number of landmarks.

        Returns:
            tuple[np.ndarray, list[int]]: Two constraint rows and their right-hand sides ``[0, 0]``.
        """
        l1 = [0]*L*L
        l2 = [0]*L*L
        n = len(circle_vertices)
        for k, src in enumerate(circle_vertices):
            dst = circle_vertices[(k + 1) % n]      # successor; the last vertex wraps back to the first
            l1[src*L + dst] = 1
            l2[dst*L + src] = 1

        return np.vstack((l1, l2)), [0, 0]


    def is_connected(self, resx):
        """
        Split a solution into the journey that contains the start and any detached journeys (circles).

        ``resx`` is rounded to integers **in place**. Journeys are collected by a depth-first
        traversal of the selected edges, started from every departing landmark in increasing index
        order and skipping landmarks already covered by a previous journey.

        Args:
            resx (list): Flattened L*L edge weights of a solution.

        Returns:
            tuple[list[int], Optional[list[list[int]]]]: The first journey containing landmark 0
            (``[]`` if none does) and the list of remaining journeys, or ``None`` if there are none.
        """
        # first round the results to have only 0-1 values
        for i, elem in enumerate(resx):
            resx[i] = round(elem)

        N = len(resx)               # length of res
        L = int(np.sqrt(N))         # number of landmarks; N = L**2 by construction

        nonzero_tup = np.unravel_index(np.nonzero(resx)[0], (L, L))
        ind_a = nonzero_tup[0].tolist()     # departure landmark of each selected edge
        ind_b = nonzero_tup[1].tolist()     # arrival landmark of each selected edge

        # Adjacency list of the selected edges.
        graph = defaultdict(list)
        for a, b in zip(ind_a, ind_b):
            graph[a].append(b)

        def get_journey(start):
            """Return the landmarks reachable from ``start``, in depth-first visiting order."""
            journey_nodes = []
            visited = set()
            stack = deque([start])
            while stack:
                node = stack.pop()
                if node in visited:
                    continue
                visited.add(node)
                journey_nodes.append(node)
                stack.extend(nb for nb in graph[node] if nb not in visited)
            return journey_nodes

        all_journeys_nodes = []
        visited_nodes = set()
        for node in ind_a:
            if node not in visited_nodes:
                journey_nodes = get_journey(node)
                all_journeys_nodes.append(journey_nodes)
                visited_nodes.update(journey_nodes)

        # The journey containing the start is the tour; every other journey is detached.
        order = []
        for k, journey in enumerate(all_journeys_nodes):
            if 0 in journey:
                order = all_journeys_nodes.pop(k)
                break

        return order, (all_journeys_nodes or None)


    def init_ub_dist(self, landmarks: list[Landmark], max_time: int):
        """
        Build the objective coefficients and the travel-time budget constraint.

        The cost of edge ``i -> j`` is ``get_time(i.location, j.location) + i.duration``. Only the 25
        cheapest edges leaving each landmark (plus any edge tied with the 25th) keep that cost; all
        others get the prohibitive cost 32700. The objective rewards ``-attractiveness`` of the
        *arrival* landmark on every edge, because linprog minimizes.

        Args:
            landmarks (list[Landmark]): List of landmarks.
            max_time (int): Maximum time of visit allowed.

        Returns:
            tuple[list[float], list[float], list[float]]: Objective coefficients (length L*L), the flattened
            cost row of the budget constraint (length L*L), and its right-hand side ``[max_time*overshoot]``.
        """
        n_closest = 25      # outgoing edges per landmark that keep their real cost
        far_away = 32700    # prohibitive cost; stays below the int16 maximum (A_ub is cast to int16 later)

        L = len(landmarks)
        c = []
        A_ub = []
        for spot1 in landmarks:
            c.append(-spot1.attractiveness)
            dist_table = [get_time(spot1.location, spot2.location) + spot1.duration for spot2 in landmarks]
            # Value of the n_closest-th cheapest edge: anything costlier is priced out.
            cutoff = sorted(dist_table)[:n_closest][-1]
            A_ub += [dist if dist <= cutoff else far_away for dist in dist_table]

        # x[i*L + j] is rewarded with the attractiveness of the arrival landmark j.
        c = c*L

        return c, A_ub, [max_time*self.overshoot]


    def respect_number(self, L: int, max_landmarks: int):
        """
        Generate constraints so each landmark is left at most once and the total number of edges is capped.

        Row ``i`` (for ``i`` in ``0..L-1``) sums every departure from landmark ``i`` and is bounded by 1.
        The last row sums all edges and is bounded by ``max_landmarks + 1``.

        Args:
            L (int): Number of landmarks.
            max_landmarks (int): Maximum number of landmarks to visit.

        Returns:
            tuple[np.ndarray, list[int]]: Inequality constraint coefficients (L+1 rows) and their right-hand sides.
        """
        # Repeating each identity column L times puts the ones of row i on x[i*L : (i+1)*L].
        per_landmark = np.repeat(np.eye(L, dtype=int), L, axis=1)
        A = np.vstack((per_landmark, np.ones(L*L, dtype=int)))
        b = [1]*L + [max_landmarks + 1]
        return A, b


    def break_sym(self, L: int):
        """
        Generate constraints preventing travel between two landmarks in both directions (``i -> j -> i``).

        This does not prevent longer cycles. The first row is an all-zero placeholder (bound 1); then
        one row per pair ``i < j`` (row-major order) covers ``x[i, j] + x[j, i] <= 1``.

        Args:
            L (int): Number of landmarks.

        Returns:
            tuple[np.ndarray, list[int]]: Inequality constraint coefficients and the right-hand sides.
        """
        rows_i, rows_j = np.triu_indices(L, k=1)
        n_pairs = len(rows_i)

        A = np.zeros((n_pairs + 1, L*L), dtype=int)
        r = np.arange(1, n_pairs + 1)
        A[r, rows_i*L + rows_j] = 1
        A[r, rows_j*L + rows_i] = 1

        return A, [1]*(n_pairs + 1)


    def init_eq_not_stay(self, L: int):
        """
        Generate the constraint that forbids staying on the same landmark (``x[i, i] == 0`` for all i).

        Args:
            L (int): Number of landmarks.

        Returns:
            tuple[list[np.ndarray], list[int]]: One equality row (int8, ones on the diagonal entries) and ``[0]``.
        """
        return [np.eye(L, dtype=np.int8).ravel()], [0]


    @staticmethod
    def _departure_rows(L: int, indices: list[int], rhs: int):
        """Build an all-zero placeholder row (bound 0) plus one row per landmark index summing its departures (bound ``rhs``)."""
        A = np.zeros((len(indices) + 1, L*L), dtype=int)
        for r, i in enumerate(indices, start=1):
            A[r, i*L:(i+1)*L] = 1
        return A, [0] + [rhs]*len(indices)


    def respect_user_must_do(self, landmarks: list[Landmark]):
        """
        Generate equality constraints forcing exactly one departure from every landmark tagged 'must_do'.

        Landmark 0 and landmarks named 'start' or 'finish' are skipped. The first row is an all-zero
        placeholder with right-hand side 0.

        Args:
            landmarks (list[Landmark]): List of landmarks, where some are marked as 'must_do'.

        Returns:
            tuple[np.ndarray, list[int]]: Equality constraint coefficients and the right-hand sides.
        """
        must_do = [
            i for i, elem in enumerate(landmarks[1:], start=1)
            if elem.must_do is True and elem.name not in ['finish', 'start']
        ]
        return self._departure_rows(len(landmarks), must_do, 1)


    def respect_user_must_avoid(self, landmarks: list[Landmark]):
        """
        Generate equality constraints forbidding any departure from landmarks tagged 'must_avoid'.

        Landmark 0 and landmarks named 'start' or 'finish' are skipped. The first row is an all-zero
        placeholder with right-hand side 0.

        Args:
            landmarks (list[Landmark]): List of landmarks, where some are marked as 'must_avoid'.

        Returns:
            tuple[np.ndarray, list[int]]: Equality constraint coefficients and the right-hand sides.
        """
        must_avoid = [
            i for i, elem in enumerate(landmarks[1:], start=1)
            if elem.must_avoid is True and elem.name not in ['finish', 'start']
        ]
        return self._departure_rows(len(landmarks), must_avoid, 0)


    def respect_start_finish(self, L: int):
        """
        Generate equality constraints making landmark 0 the start and landmark L-1 the finish.

        Row 0: one departure from the start, not counting the direct jump start -> finish (== 1).
        Row 1: one arrival at the finish from an intermediate landmark (== 1).
        Row 2: no departure from the finish and no arrival at the start (== 0).

        Args:
            L (int): Number of landmarks.

        Returns:
            tuple[np.ndarray, list[int]]: Equality constraint coefficients (3 rows) and ``[1, 1, 0]``.
        """
        A = np.zeros((3, L, L), dtype=int)      # one L x L edge matrix per constraint
        A[0, 0, :] = 1
        A[0, 0, L-1] = 0
        A[1, 1:L-1, L-1] = 1
        A[2, L-1, :] = 1
        A[2, :, 0] = 1
        return A.reshape(3, L*L), [1, 1, 0]


    def respect_order(self, L: int):
        """
        Generate flow-conservation constraints for intermediate landmarks (does not fully prevent circles).

        After an all-zero placeholder row, each landmark ``i`` in ``1..L-2`` gets a row with -1 on every
        arrival at ``i`` and +1 on every departure from ``i`` (the diagonal entry ``x[i, i]`` ends up +1),
        constrained to 0: a landmark is left exactly as often as it is reached.

        Args:
            L (int): Number of landmarks.

        Returns:
            tuple[np.ndarray, list[int]]: Equality constraint coefficients and the right-hand sides.
        """
        middle = range(1, L-1)
        n_rows = len(middle) + 1
        A = np.zeros((n_rows, L, L), dtype=int)
        for r, i in enumerate(middle, start=1):
            A[r, :, i] = -1     # arrivals at i
            A[r, i, :] = 1      # departures from i (written last, so x[i, i] is +1)
        return A.reshape(n_rows, L*L), [0]*n_rows


    def link_list(self, order: list[int], landmarks: list[Landmark]) -> list[Landmark]:
        """
        Chain the landmarks of ``order`` together, computing the time to reach each next landmark.

        The Landmark objects are modified in place: every landmark but the last gets
        ``time_to_reach_next`` (from ``get_time`` on the un-rounded locations) and ``next_uuid``;
        every landmark gets ``must_do = True`` and its location rounded to 5 decimals.

        Args:
            order (list[int]): List of indices representing the order of landmarks to visit.
            landmarks (list[Landmark]): List of all landmarks.

        Returns:
            list[Landmark]: The linked landmarks in visiting order (empty if ``order`` is empty).
        """
        tour = [landmarks[k] for k in order]

        for elem, nxt in zip(tour, tour[1:]):
            elem.time_to_reach_next = get_time(elem.location, nxt.location)
            elem.next_uuid = nxt.uuid

        for elem in tour:
            elem.must_do = True
            elem.location = (round(elem.location[0], 5), round(elem.location[1], 5))

        return tour


    def solve_optimization(
            self,
            max_time: int,
            landmarks: list[Landmark],
            max_landmarks: int = None
        ) -> list[Landmark]:
        """
        Main optimization pipeline to solve the landmark visiting problem.

        Sets up and solves an integer linear program that picks a tour from the start (landmark 0)
        to the finish (landmark L-1) within the time budget, honouring the must-do / must-avoid tags
        and the cap on visited landmarks. While the solution contains detached circles, that landmark
        selection and every edge of those circles are forbidden and the problem is solved again.

        Args:
            max_time (int): Maximum time allowed for the tour in minutes.
            landmarks (list[Landmark]): Landmarks to choose from; the first is the start, the last the finish.
            max_landmarks (int): Maximum number of landmarks visited. Defaults to ``self.max_landmarks``.

        Returns:
            list[Landmark]: The landmarks of the optimized tour, in visiting order.

        Raises:
            ArithmeticError: If the linear program (or one of its re-solves) has no solution.
            TimeoutError: If circles are still present after the maximum number of re-solves.
        """
        if max_landmarks is None:
            max_landmarks = self.max_landmarks

        L = len(landmarks)

        # SET CONSTRAINTS FOR INEQUALITY
        c, A_ub, b_ub = self.init_ub_dist(landmarks, max_time)      # travel-time budget
        # NOTE(review): np.vstack casts with 'same_kind', so the int16 cast assumes get_time returns integers — confirm.
        for A, b in (
            self.respect_number(L, max_landmarks),  # leave each landmark at most once, cap total edges
            self.break_sym(L),                      # break the 'zig-zag' symmetry i -> j -> i
        ):
            A_ub = np.vstack((A_ub, A), dtype=np.int16)
            b_ub += b

        # SET CONSTRAINTS FOR EQUALITY
        A_eq, b_eq = self.init_eq_not_stay(L)                       # never stay on the same landmark
        for A, b in (
            self.respect_user_must_do(landmarks),       # leave every must-do landmark exactly once
            self.respect_user_must_avoid(landmarks),    # never leave a must-avoid landmark
            self.respect_start_finish(L),               # force start and finish positions
            self.respect_order(L),                      # flow conservation at intermediate landmarks
        ):
            A_eq = np.vstack((A_eq, A), dtype=np.int8)
            b_eq += b

        # Bounds [0, 1]; integrality=3 (semi-integer: 0 or an integer within bounds) makes x binary.
        x_bounds = [(0, 1)]*L*L

        res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=x_bounds, method='highs', integrality=3)
        if not res.success:
            raise ArithmeticError("No solution could be found, the problem is overconstrained. Try with a longer trip (>30 minutes).")

        order, circles = self.is_connected(res.x)
        timeout = 80            # maximum number of re-solves
        n_resolves = 0
        while circles is not None:
            if n_resolves == timeout:
                raise TimeoutError(f"Optimization took too long. No solution found after {timeout} iterations.")

            # Forbid this landmark selection...
            A, b = self.prevent_config(res.x)
            A_ub = np.vstack((A_ub, A))
            b_ub += b
            # ...and every edge of each detached circle, in both directions.
            for circle in circles:
                A, b = self.prevent_circle(circle, L)
                A_eq = np.vstack((A_eq, A))
                b_eq += b

            res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=x_bounds, method='highs', integrality=3)
            if not res.success:
                raise ArithmeticError("Solving failed because of overconstrained problem")
            n_resolves += 1
            order, circles = self.is_connected(res.x)

        # sort the landmarks in the order of the solution
        tour = [landmarks[k] for k in order]

        self.logger.debug("Re-optimized %d times, score: %d", n_resolves, int(-res.fun))
        return tour