import math as m
import yaml
import logging

from OSMPythonTools.overpass import Overpass, overpassQueryBuilder
from OSMPythonTools.cachingStrategy import CachingStrategy, JSON
from pywikibot import ItemPage, Site
from pywikibot import config
# Relax pywikibot's request pacing for the Wikidata lookups made below.
# NOTE(review): assumes 0 disables the write throttle and the server-lag
# (maxlag) deferral -- confirm against pywikibot's config documentation.
config.put_throttle = 0
config.maxlag = 0

from structs.preferences import Preferences, Preference
from structs.landmarks import Landmark
from utils import take_most_important
import constants


# Landmark category names; also used as keys of the amenity selectors config.
SIGHTSEEING, NATURE, SHOPPING = 'sightseeing', 'nature', 'shopping'



class LandmarkManager:
    """
    Fetch OpenStreetMap landmarks around a location, score them and keep the most important ones.

    Overpass selectors and scoring parameters are loaded from the YAML files at
    ``constants.AMENITY_SELECTORS_PATH`` and ``constants.LANDMARK_PARAMETERS_PATH``.
    """

    logger = logging.getLogger(__name__)

    city_bbox_side: int     # bbox side in meters
    radius_close_to: int    # radius in meters
    church_coeff: float     # coeff to adjust score of churches
    park_coeff: float       # coeff to adjust score of parks
    tag_coeff: float        # coeff to adjust weight of tags
    N_important: int        # number of important landmarks to consider


    def __init__(self) -> None:
        """Load selectors and scoring parameters, and set up the Overpass client with a JSON cache."""

        with constants.AMENITY_SELECTORS_PATH.open('r') as f:
            self.amenity_selectors = yaml.safe_load(f)

        with constants.LANDMARK_PARAMETERS_PATH.open('r') as f:
            parameters = yaml.safe_load(f)
            self.city_bbox_side = parameters['city_bbox_side']
            self.radius_close_to = parameters['radius_close_to']
            self.church_coeff = parameters['church_coeff']
            self.park_coeff = parameters['park_coeff']
            self.tag_coeff = parameters['tag_coeff']
            self.N_important = parameters['N_important']

        self.overpass = Overpass()
        CachingStrategy.use(JSON, cacheDir=constants.OSM_CACHE_DIR)


    def generate_landmarks_list(self, center_coordinates: tuple[float, float], preferences: Preferences) -> tuple[list[Landmark], list[Landmark]]:
        """
        Generate and prioritize a list of landmarks based on user preferences.

        This method fetches landmarks from various categories (sightseeing, nature, shopping) based on the user's preferences
        and current location. It scores and corrects these landmarks, removes duplicates (keeping the highest-scored one per
        name), and then selects the most important landmarks based on a predefined criterion.

        Parameters:
        center_coordinates (tuple[float, float]): The latitude and longitude of the center location around which to search.
        preferences (Preferences): The user's preference settings that influence the landmark selection.

        Returns:
            tuple[list[Landmark], list[Landmark]]:
                - A list of all existing landmarks.
                - A list of the most important landmarks based on the user's preferences.
        """

        L = []
        bbox = self.create_bbox(center_coordinates)

        # (landmark type, user preference, multiplier applied to the raw score)
        categories = (
            (SIGHTSEEING, preferences.sightseeing, self.church_coeff),
            (NATURE, preferences.nature, self.park_coeff),
            (SHOPPING, preferences.shopping, 1),
        )
        for landmarktype, preference, coeff in categories:
            # a preference score of 0 disables the category entirely
            if preference.score == 0:
                continue
            fetched = self.fetch_landmarks(bbox, self.amenity_selectors[landmarktype], landmarktype, self._make_score_function(coeff))
            self.correct_score(fetched, preference)
            L += fetched

        L = self.remove_duplicates(L)
        L_constrained = take_most_important(L, self.N_important)
        self.logger.info(f'Generated {len(L)} landmarks around {center_coordinates}, and constrained to {len(L_constrained)} most important ones.')

        return L, L_constrained


    def _make_score_function(self, coeff: float = 1) -> callable:
        """Build the `(location, n_tags) -> int` score function for one category, scaled by `coeff`."""

        def score_function(location: tuple[float, float], n_tags: float) -> int:
            # number of surrounding OSM ways/relations plus a super-linear bonus for richly tagged elements
            return int((self.count_elements_close_to(location) + (n_tags ** 1.2) * self.tag_coeff) * coeff)

        return score_function


    def remove_duplicates(self, landmarks: list[Landmark]) -> list[Landmark]:
        """
        Remove landmarks sharing the same name, keeping the one with the highest attractiveness.

        Parameters:
        landmarks (list[Landmark]): A list of Landmark objects.

        Returns:
        list[Landmark]: One Landmark per distinct name. Names keep the order of their first occurrence;
        on equal attractiveness the earliest landmark is kept.
        """

        best_by_name = {}
        for landmark in landmarks:
            kept = best_by_name.get(landmark.name)
            if kept is None or landmark.attractiveness > kept.attractiveness:
                # replacing a dict value keeps the key's original position
                best_by_name[landmark.name] = landmark

        return list(best_by_name.values())


    def correct_score(self, landmarks: list[Landmark], preference: Preference):
        """
        Adjust the attractiveness score of each landmark in the list based on user preferences.

        Each landmark's attractiveness is scaled in place by `preference.score / 5` and truncated to an int.

        Args:
            landmarks (list[Landmark]): A list of landmarks whose scores need to be corrected.
            preference (Preference): The user's preference settings that influence the attractiveness score adjustment.

        Raises:
            TypeError: If the type of the first landmark in the list does not match the type of the preference.
        """

        if not landmarks:
            return

        # NOTE(review): only the first landmark is checked. fetch_landmarks reports parks as NATURE even in
        # a non-nature list, so a park listed first could trigger this -- confirm the values of Preference.type.
        if landmarks[0].type != preference.type:
            raise TypeError(f"LandmarkType {preference.type} does not match the type of Landmark {landmarks[0].name}")

        for elem in landmarks:
            elem.attractiveness = int(elem.attractiveness*preference.score/5)     # arbitrary computation


    def count_elements_close_to(self, coordinates: tuple[float, float]) -> int:
        """
        Count the OpenStreetMap ways and relations close to the given location.

        A square bounding box of half-side `radius_close_to` meters is built around the coordinates and
        queried through Overpass; only ways and relations are requested, since only those are counted.

        Args:
            coordinates (tuple[float, float]): The latitude and longitude of the location to search around.

        Returns:
            int: The number of ways and relations inside the box. Returns 0 if no elements are found or
                if the query fails (the failure is logged).
        """

        lat, lon = coordinates[0], coordinates[1]
        radius = self.radius_close_to

        # Half-widths of the box in degrees: one degree of latitude spans pi * R / 180 meters
        # (R = 6371 km, mean Earth radius); a degree of longitude shrinks by cos(lat), so the
        # longitude half-width is widened accordingly (same correction as in `create_bbox`).
        lat_alpha = (180*radius) / (6371000*m.pi)
        lon_alpha = lat_alpha / m.cos(m.radians(lat))

        radius_query = overpassQueryBuilder(
            bbox=[lat - lat_alpha, lon - lon_alpha, lat + lat_alpha, lon + lon_alpha],
            elementType=['way', 'relation']
        )

        try:
            radius_result = self.overpass.query(radius_query)
            N_elem = (radius_result.countWays() or 0) + (radius_result.countRelations() or 0)
        except Exception as e:
            self.logger.warning(f"Error counting elements close to {coordinates}: {e}")
            return 0

        self.logger.debug(f"There are {N_elem} ways/relations within {radius}m")
        return N_elem


    def create_bbox(self, coordinates: tuple[float, float]) -> tuple[float, float, float, float]:
        """
        Create a square bounding box of side `city_bbox_side` meters around the given coordinates.

        Args:
            coordinates (tuple[float, float]): The latitude and longitude of the center of the bounding box.

        Returns:
            tuple[float, float, float, float]: The minimum latitude, minimum longitude, maximum latitude, and maximum longitude
                                                defining the bounding box.
        """

        lat, lon = coordinates[0], coordinates[1]

        # Half the side length in km (since it's a square bbox)
        half_side_length_km = self.city_bbox_side / 2 / 1000

        # Convert distance to degrees
        lat_diff = half_side_length_km / 111  # 1 degree latitude is approximately 111 km
        lon_diff = half_side_length_km / (111 * m.cos(m.radians(lat)))  # Adjust for longitude based on latitude

        return lat - lat_diff, lon - lon_diff, lat + lat_diff, lon + lon_diff


    def fetch_landmarks(self, bbox: tuple, amenity_selector: dict, landmarktype: str, score_function: callable) -> list[Landmark]:
        """
        Fetch landmarks of a specified type from OpenStreetMap (OSM) within a bounding box.

        Args:
            bbox (tuple[float, float, float, float]): The bounding box coordinates (min_lat, min_lon, max_lat, max_lon).
            amenity_selector (dict): The Overpass API query selector for the desired landmark type.
            landmarktype (str): The type of the landmark (e.g., 'sightseeing', 'nature', 'shopping').
            score_function (callable): The function to compute the score of the landmark based on its attributes.

        Returns:
            list[Landmark]: A list of Landmark objects that were fetched and filtered based on the provided criteria.
                A selector whose Overpass query fails is logged and skipped, so the list may be partial but is
                never None.

        Notes:
            - Landmarks are fetched using Overpass API queries.
            - Selectors are translated from the dictionary to the Overpass query format. (e.g., 'amenity'='place_of_worship')
            - Unnamed, unlocated, disused, building-part and (outside shopping) shop-like elements are discarded.
            - Parks are reported with type NATURE whatever the requested landmark type.
            - Scores are assigned to landmarks based on their attributes and surrounding elements;
              elements scoring 0 are dropped.
        """
        return_list = []

        # caution, when applying a list of selectors, overpass will search for elements that match ALL selectors simultaneously
        # we need to split the selectors into separate queries and merge the results
        for sel in dict_to_selector_list(amenity_selector):
            self.logger.debug(f"Current selector: {sel}")
            query = overpassQueryBuilder(
                bbox = bbox,
                elementType = ['way', 'relation'],
                selector = sel,
                includeCenter = True,
                out = 'body'
                )

            try:
                result = self.overpass.query(query)
            except Exception as e:
                # keep the results of the other selectors rather than returning None
                self.logger.error(f"Error fetching landmarks: {e}")
                continue

            for elem in result.elements():

                name = elem.tag('name')                             # Add name
                location = (elem.centerLat(), elem.centerLon())     # Add coordinates (lat, lon)

                # TODO: exclude these from the get go
                # skip if unnamed or unprecise location
                if name is None or location[0] is None:
                    continue

                tags = elem.tags()
                # filter before scoring so that no Wikidata lookup is spent on discarded elements
                if self._is_excluded(tags, landmarktype):
                    continue

                elem_type = NATURE if tags.get('leisure') == 'park' else landmarktype
                n_tags = self._count_weighted_tags(tags)

                score = score_function(location, n_tags)
                if score != 0:
                    # Generate the landmark and append it to the list
                    return_list.append(Landmark(
                        name=name,
                        type=elem_type,
                        location=location,
                        osm_type=elem.type(),   # 'way' or 'relation'
                        osm_id=elem.id(),
                        attractiveness=score,
                        must_do=False,
                        n_tags=int(n_tags)
                    ))

        self.logger.debug(f"Fetched {len(return_list)} landmarks of type {landmarktype} in {bbox}")

        return return_list


    def _is_excluded(self, tags: dict, landmarktype: str) -> bool:
        """
        Tell whether an element with the given OSM tags must be discarded for `landmarktype`.

        An element is discarded when it is part of another building, has any tag key containing
        'disused', or -- for non-shopping types -- has a tag key containing 'shop' or a
        building=retail/supermarket/parking tag.
        """
        # skip if part of another building
        if tags.get('building:part') == 'yes':
            return True

        for tag in tags:
            # skip disused amenities (also covers 'disused:leisure')
            if 'disused' in tag:
                return True
            if landmarktype != SHOPPING:
                if 'shop' in tag:
                    return True
                if tag == 'building' and tags[tag] in ('retail', 'supermarket', 'parking'):
                    return True

        return False


    def _count_weighted_tags(self, tags: dict) -> float:
        """
        Compute the weighted tag count used as an importance signal.

        Tag keys containing 'pay' are not counted, each tag key containing 'wikipedia' adds 3,
        and a 'wikidata' tag adds a tenth of the number of labels of the referenced Wikidata item.
        """
        n_tags = len(tags)
        for tag in tags:
            if 'pay' in tag:
                n_tags -= 1             # discard payment options for tags
            if 'wikipedia' in tag:
                n_tags += 3             # wikipedia entries count more

        if 'wikidata' in tags:
            n_tags += self._count_wikidata_labels(tags['wikidata']) / 10

        return n_tags


    def _count_wikidata_labels(self, qid: str) -> int:
        """
        Return the number of labels (one per language) of Wikidata item `qid`, or 0 if it cannot be fetched.

        A failed lookup (network error, malformed or missing id, ...) only costs the element its
        Wikidata bonus instead of aborting the whole landmark generation.
        """
        try:
            item = ItemPage(Site("wikidata", "wikidata"), qid)
            item.get()
        except Exception as e:
            self.logger.warning(f"Could not fetch wikidata item {qid}: {e}")
            return 0

        return len(item.labels)



def dict_to_selector_list(d: dict) -> list:
    """
    Convert a dictionary of key-value pairs to a list of Overpass query selector strings.

    Each entry becomes its own selector:
      - a list/tuple value, e.g. ``{'tourism': ['museum', 'gallery']}``, gives ``'tourism~"museum|gallery"'``
        (regex alternation);
      - an empty string or a missing value (``None``, i.e. ``key:`` in YAML) gives ``'key'`` (tag presence);
      - any other value gives ``'key=value'``.

    Values are stringified. Booleans become ``yes``/``no`` because ``yaml.safe_load`` (YAML 1.1)
    turns unquoted OSM values such as ``yes`` into ``True``, which would otherwise yield ``key=True``.

    Args:
        d (dict): A dictionary of key-value pairs representing the selector.

    Returns:
        list: A list of strings representing the Overpass query selectors, in the dictionary's order.
    """

    def as_osm_value(value) -> str:
        # undo YAML 1.1 boolean coercion of OSM 'yes'/'no' values
        if isinstance(value, bool):
            return 'yes' if value else 'no'
        return str(value)

    return_list = []
    for key, value in d.items():
        if isinstance(value, (list, tuple)):
            val = '|'.join(as_osm_value(v) for v in value)
            return_list.append(f'{key}~"{val}"')
        elif value is None or value == '':
            return_list.append(f'{key}')
        else:
            return_list.append(f'{key}={as_osm_value(value)}')
    return return_list