import astropy.units as u
import numpy as np
import logging
logger = logging.getLogger(__name__)


# Module-level simulation scales. Both stay None until seed_scales() assigns
# them; apply_units() multiplies by them, so they must be set before use.
M_SCALE: "u.Quantity | None" = None  # characteristic mass scale
R_SCALE: "u.Quantity | None" = None  # characteristic length scale

def seed_scales(r_scale: u.Quantity, m_scale: u.Quantity):
    """
    Set the scales for the given simulation.

    Parameters:
    - r_scale: astropy.units.Quantity with units of length - the characteristic length scale of the simulation. Particle positions are expressed in units of this scale.
    - m_scale: astropy.units.Quantity with units of mass - the characteristic mass scale of the simulation. Particle masses are expressed in units of this scale.

    Raises:
    - astropy.units.UnitTypeError: if r_scale is not a length Quantity or
      m_scale is not a mass Quantity. Both arguments are checked before any
      global is assigned, so a rejected call leaves the previous scales intact.
    """
    global M_SCALE, R_SCALE
    for name, value, reference, kind in (
        ("r_scale", r_scale, u.m, "length"),
        ("m_scale", m_scale, u.kg, "mass"),
    ):
        # Plain numbers have no `.unit`; reject them just like a wrong unit,
        # otherwise apply_units() would silently produce meaningless results.
        unit = getattr(value, "unit", None)
        if unit is None or not unit.is_equivalent(reference):
            raise u.UnitTypeError(
                f"{name} must be an astropy Quantity with units of {kind}, got {value!r}"
            )
    M_SCALE = m_scale
    R_SCALE = r_scale
    logger.info(f"Set scales: M_SCALE = {M_SCALE:.2g}, R_SCALE = {R_SCALE:.2g}")


def apply_units(columns: np.ndarray, quantity: str):
    """
    Convert values expressed in simulation (code) units to astropy quantities.

    The simulation is assumed to use code units with G = 1, so the physical
    scale of every derived quantity is built from M_SCALE, R_SCALE and the
    gravitational constant G (G must appear explicitly for the resulting
    units to be convertible, e.g. to km/s or N).

    Parameters:
    - columns: values in simulation units (array or scalar).
    - quantity: one of "mass", "position", "volume", "density", "force",
      "velocity", "time".

    Returns:
    - columns multiplied by the astropy scale factor for `quantity`.

    Raises:
    - ValueError: if `quantity` is not one of the names above.
    - RuntimeError: if seed_scales() has not set M_SCALE and R_SCALE yet.
    """
    from astropy.constants import G  # same package as astropy.units

    # Factories keep the computation lazy: only the requested scale is built,
    # and the globals are read at call time (after seed_scales has run).
    scale_factories = {
        "mass": lambda: M_SCALE,
        "position": lambda: R_SCALE,
        "volume": lambda: R_SCALE**3,
        "density": lambda: M_SCALE / R_SCALE**3,
        ## Derived quantities (code units have G = 1, physical ones do not)
        # F = G M m / R^2  =>  F_scale = G * M_SCALE**2 / R_SCALE**2
        "force": lambda: (G * M_SCALE**2 / R_SCALE**2).decompose(),
        # Virial theorem: v^2 = G M / R  =>  v_scale = sqrt(G * M_SCALE / R_SCALE)
        "velocity": lambda: np.sqrt(G * M_SCALE / R_SCALE).decompose(),
        # t_scale = R_SCALE / v_scale = sqrt(R_SCALE**3 / (G * M_SCALE)), i.e.
        # 1/sqrt(G * rho) for the density scale M_SCALE / R_SCALE**3 used above.
        # No 4/3*pi factor: that would make time inconsistent with
        # position / velocity in a G = 1 unit system.
        "time": lambda: np.sqrt(R_SCALE**3 / (G * M_SCALE)).decompose(),
    }

    try:
        make_scale = scale_factories[quantity]
    except KeyError:
        raise ValueError(f"Unknown quantity: {quantity}") from None

    if M_SCALE is None or R_SCALE is None:
        raise RuntimeError(
            "Simulation scales are not set; call seed_scales() before apply_units()."
        )
    return columns * make_scale()