"""Module for connecting to the Overpass API and fetching data from OSM."""
import logging
import math
import os
import urllib
import urllib.error
import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET

from .caching_strategy import get_cache_key, CachingStrategy
from ..constants import OSM_CACHE_DIR, OSM_TYPES


# Side length, in degrees of latitude and of longitude, of the square grid cells that
# Overpass.get_grid_cell / get_bbox_from_grid_cell use as the unit of caching.
RESOLUTION = 0.05


class Overpass :
    """
    Overpass class to manage the query building and sending to overpass api.
    The caching strategy is a part of this class and initialized upon creation of the Overpass object.

    Results are cached per grid cell of ``RESOLUTION`` x ``RESOLUTION`` degrees. A query is
    answered from the cache only when every cell overlapping its bbox is cached; otherwise the
    missing cells are registered as "hollow" entries (filled later through ``fill_cache``)
    and the query is sent live to the Overpass API.
    """
    logger = logging.getLogger(__name__)


    def __init__(self, caching_strategy: str = 'XML', cache_dir: str = OSM_CACHE_DIR) :
        """
        Initialize the Overpass instance with the url, headers and caching strategy.

        Args:
            caching_strategy (str): Name of the strategy passed to ``CachingStrategy.use``.
            cache_dir (str): Directory handed to the caching strategy.
        """
        self.overpass_url = "https://overpass-api.de/api/interpreter"
        self.headers = {'User-Agent': 'Mozilla/5.0 (compatible; OverpassQuery/1.0; +http://example.com)',}
        self.caching_strategy = CachingStrategy.use(caching_strategy, cache_dir=cache_dir)


    def send_query(self, bbox: tuple, osm_types: OSM_TYPES,
                    selector: str, conditions=None, out='center') -> ET.Element:
        """
        Fetch the OSM elements matching a query, from the cache when possible.

        Args:
            bbox (tuple): Search area as (lat_min, lon_min, lat_max, lon_max).
            osm_types (str | list[str]): OSM element type(s) to search for, e.g. 'node', 'way'.
            selector (str): Tag filter applied to the elements (e.g. 'amenity').
            conditions (str | list[str], optional): Extra Overpass QL conditions, combined
                with '&&'. Defaults to no conditions.
            out (str, optional): Overpass output mode. Defaults to 'center'.

        Returns:
            ET.Element: An element whose children are the OSM elements. On a full cache hit
            it combines the cached data of every overlapping cell (whole cells are cached,
            so it may contain elements outside ``bbox``); otherwise it is the root of the
            live Overpass API response for ``bbox``.

        Raises:
            ConnectionError: If the Overpass API cannot be reached.
        """
        osm_types = Overpass._as_list(osm_types)
        conditions = Overpass._as_list(conditions)

        # Determine which grid cells overlap with this bounding box.
        overlapping_cells = Overpass.get_overlapping_cells(bbox)

        # One cache key per cell. The key covers *every* requested type, so that queries
        # such as ['node', 'way'] and ['way'] never share (and corrupt) a cache entry.
        types_str = ','.join(osm_types)
        cell_key_dict = {
            cell: get_cache_key(f"{types_str}[{selector}]{conditions}({','.join(map(str, cell))})")
            for cell in overlapping_cells
        }

        cached_responses = []
        hollow_cache_keys = []

        # Retrieve the cached data and mark the missing entries as hollow
        for cell, key in cell_key_dict.items():
            cached_data = self.caching_strategy.get(key)
            if cached_data is not None :
                cached_responses.append(cached_data)
            else:
                # Cache miss: record a hollow entry describing this cell
                # (presumably what fill_cache() later reads back and resolves).
                self.caching_strategy.set_hollow(key, cell, osm_types, selector, conditions, out)
                hollow_cache_keys.append(key)

        # If there is no missing data, return the cached responses
        if not hollow_cache_keys :
            self.logger.debug('Cache hit.')
            return self.combine_cached_data(cached_responses)

        # TODO If there is SOME missing data : hybrid stuff with partial cache

        # At least one cell is missing: query the whole bbox live. The live result is not
        # stored here; the hollow entries registered above are filled separately.
        query_str = Overpass.build_query(bbox, osm_types, selector, conditions, out)
        self.logger.debug('Cache miss. Fetching data through Overpass\nQuery = %s', query_str)
        return self._fetch(query_str)


    def fill_cache(self, xml_string: str) :
        """
        Fetch the data described by one hollow cache entry and store it in the cache.

        Args:
            xml_string (str): XML content of a hollow entry, in the format parsed by
                ``build_query_from_hollow``.

        Raises:
            ConnectionError: If the Overpass API cannot be reached.
        """
        # Build the query using info from hollow cache entry
        query_str, cache_key = Overpass.build_query_from_hollow(xml_string)

        root = self._fetch(query_str)

        self.caching_strategy.set(cache_key, root)
        self.logger.debug('Cache set')


    def _fetch(self, query_str: str) -> ET.Element:
        """POST an Overpass QL query to the API and return the parsed XML root element."""
        # The query travels as the form field 'data'; passing data= makes urllib send a POST.
        data = urllib.parse.urlencode({'data': query_str}).encode('utf-8')
        request = urllib.request.Request(self.overpass_url, data=data, headers=self.headers)

        try:
            with urllib.request.urlopen(request) as response:
                response_data = response.read().decode('utf-8')
        except urllib.error.URLError as e:
            raise ConnectionError(f"Error connecting to Overpass API: {e}") from e

        return ET.fromstring(response_data)


    @staticmethod
    def _as_list(value) -> list:
        """Normalize None, a single item, or a list/tuple of items into a new list."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]


    @staticmethod
    def build_query(bbox: tuple, osm_types: OSM_TYPES,
                    selector: str, conditions=None, out='center') -> str:
        """
        Constructs a query string for the Overpass API to retrieve OpenStreetMap (OSM) data.

        Args:
            bbox (tuple): A tuple representing the geographical search area, typically in the format
                        (lat_min, lon_min, lat_max, lon_max).
            osm_types (str | list[str]): One OSM element type or a list of them
                                    (e.g. 'node', 'way', 'relation').
            selector (str): The key or tag to filter the OSM elements (e.g., 'amenity', 'highway', etc.).
            conditions (str | list[str], optional): Conditions to apply as additional filters for
                                        the selected OSM elements. The conditions should be written
                                        in the Overpass QL format, and they are combined with '&&'
                                        if multiple are provided. Defaults to no conditions.
            out (str, optional): Specifies the output type, such as 'center', 'body', or 'tags'.
                                Defaults to 'center'.

        Returns:
            str: The constructed Overpass QL query string, of the form
            ``(type[selector](if: ...)(bbox);...);out <out>;``.

        Notes:
            - If no conditions are provided, the query will just use the `selector` to filter the OSM
            elements without additional constraints.
        """
        osm_types = Overpass._as_list(osm_types)
        conditions = Overpass._as_list(conditions)

        # convert the bbox to string.
        bbox_str = f"({','.join(map(str, bbox))})"

        condition_str = f"(if: {' && '.join(conditions)})" if conditions else ''

        # One statement per element type, grouped in a union so a single 'out' covers them all.
        statements = ''.join(
            f"{elem}[{selector}]{condition_str}{bbox_str};" for elem in osm_types
        )

        return f"({statements});out {out};"


    @staticmethod
    def build_query_from_hollow(xml_string):
        """
        Rebuild the Overpass query and cache key stored in a hollow cache entry.

        The entry must contain the children 'key', 'cell', 'osm_types', 'selector',
        'conditions' and 'out'. 'cell' is presumably written as "(lat_index, lon_index)",
        list values are comma-separated, and conditions equal to "none" (or empty) mean
        no conditions.

        Args:
            xml_string (str): XML content of the hollow entry.

        Returns:
            tuple[str, str]: (query string for the cell's bbox, cache key).
        """
        # Parse the XML string into an ElementTree object
        root = ET.fromstring(xml_string)

        # Extract values from the XML tree
        key = root.find('key').text
        cell = tuple(map(float, root.find('cell').text.strip('()').split(',')))
        bbox = Overpass.get_bbox_from_grid_cell(cell[0], cell[1])
        osm_types = root.find('osm_types').text.split(',')
        selector = root.find('selector').text
        conditions_text = root.find('conditions').text
        conditions = conditions_text.split(',') if conditions_text and conditions_text != "none" else []
        out = root.find('out').text

        query_str = Overpass.build_query(bbox, osm_types, selector, conditions, out)

        return query_str, key


    @staticmethod
    def get_grid_cell(lat: float, lon: float):
        """
        Returns the grid cell coordinates for a given latitude and longitude.
        Each grid cell is RESOLUTION degrees (currently 0.05°) of latitude by longitude.

        Returns:
            tuple[int, int]: (lat_index, lon_index), i.e. floor(coordinate / RESOLUTION).
        """
        lat_index = math.floor(lat / RESOLUTION)
        lon_index = math.floor(lon / RESOLUTION)
        return (lat_index, lon_index)


    @staticmethod
    def get_bbox_from_grid_cell(lat_index: int, lon_index: int):
        """
        Returns the bounding box for a given grid cell index.
        Each grid cell is resolution x resolution in size.

        The bounding box is returned as (min_lat, min_lon, max_lat, max_lon).
        """
        # Rounding removes float noise (e.g. 3 * 0.05 == 0.15000000000000002); two decimals
        # are enough as long as RESOLUTION itself has at most two decimals.
        # Calculate the southwest (min_lat, min_lon) corner of the bounding box
        min_lat = round(lat_index * RESOLUTION, 2)
        min_lon = round(lon_index * RESOLUTION, 2)

        # Calculate the northeast (max_lat, max_lon) corner of the bounding box
        max_lat = round((lat_index + 1) * RESOLUTION, 2)
        max_lon = round((lon_index + 1) * RESOLUTION, 2)

        return (min_lat, min_lon, max_lat, max_lon)


    @staticmethod
    def get_overlapping_cells(query_bbox: tuple):
        """
        Returns a set of all grid cells that overlap with the given bounding box.

        Args:
            query_bbox (tuple): (lat_min, lon_min, lat_max, lon_max).

        Returns:
            set[tuple[int, int]]: (lat_index, lon_index) of every cell in the index rectangle
            spanned by the bbox corners, inclusive on both ends.
        """
        # Extract location from the query bbox
        lat_min, lon_min, lat_max, lon_max = query_bbox

        min_lat_cell, min_lon_cell = Overpass.get_grid_cell(lat_min, lon_min)
        max_lat_cell, max_lon_cell = Overpass.get_grid_cell(lat_max, lon_max)

        return {
            (lat_idx, lon_idx)
            for lat_idx in range(min_lat_cell, max_lat_cell + 1)
            for lon_idx in range(min_lon_cell, max_lon_cell + 1)
        }


    @staticmethod
    def combine_cached_data(cached_data_list):
        """
        Combines data from multiple cached responses into a single ``<osm>`` element.

        Each cell is queried and cached separately, so one OSM element may appear in several
        cached responses. Children are therefore de-duplicated by (tag name, 'id' attribute);
        children without an 'id' attribute are always kept.
        """
        combined_data = ET.Element("osm")
        seen = set()
        for cached_data in cached_data_list:
            for element in cached_data:
                osm_id = element.get('id')
                if osm_id is not None:
                    identity = (element.tag, osm_id)
                    if identity in seen:
                        continue
                    seen.add(identity)
                combined_data.append(element)
        return combined_data
    

def get_base_info(elem: ET.Element, osm_type: OSM_TYPES, with_name=False) :
    """
    Pull the OSM id, the (lat, lon) position and, on request, the name out of an OSM XML element.

    Args:
        elem (ET.Element): XML element describing one OSM entity.
        osm_type (str): Entity type. For 'node' the position is read from the element's own
                        'lat'/'lon' attributes; for any other type it is read from the element's
                        'center' child (presumably present because queries use ``out center``).
        with_name (bool): When True, the 'v' value of the element's ``<tag k="name">`` child is
                          appended to the result (None when there is no such tag). Defaults to False.

    Returns:
        tuple: ``(osm_id, (lat, lon))``, or ``(osm_id, (lat, lon), name)`` when ``with_name`` is True.
        ``osm_id`` is the raw 'id' attribute string; lat and lon are floats.
    """
    # Nodes carry their own position; other types only get one through a <center> child.
    position_source = elem if osm_type == 'node' else elem.find('center')
    coords = (float(position_source.get('lat')), float(position_source.get('lon')))

    osm_id = elem.get('id')

    if not with_name :
        return osm_id, coords

    name_tag = elem.find("tag[@k='name']")
    name = None if name_tag is None else name_tag.get('v')
    return osm_id, coords, name


def fill_cache():
    """
    Resolve every hollow cache entry found directly in ``OSM_CACHE_DIR``.

    Each regular file whose name starts with 'hollow_' is read and handed to
    ``Overpass.fill_cache`` (which queries the Overpass API and stores the result through the
    caching strategy); the file is deleted once that call returns. Any exception raised while
    filling an entry propagates, leaving that file (and any not yet visited) in place.
    """
    overpass = Overpass(caching_strategy='XML', cache_dir=OSM_CACHE_DIR)

    with os.scandir(OSM_CACHE_DIR) as entries:
        for entry in entries:
            # Only hollow placeholders need work; regular cache files are skipped.
            if not (entry.is_file() and entry.name.startswith('hollow_')):
                continue

            with open(entry.path, 'r') as hollow_file:
                hollow_xml = hollow_file.read()

            overpass.fill_cache(hollow_xml)

            # The entry is now backed by real cached data, so its placeholder can go.
            os.remove(entry.path)