import numpy as np
import logging
logger = logging.getLogger(__name__)


def density_distribution(r_bins: np.ndarray, particles: np.ndarray, ret_error: bool = False):
    """
    Computes the radial mass density distribution of a set of particles.
    Assumes that the particles array has the following columns: x, y, z, m.

    The radial range is split into len(r_bins) bins: the inner sphere
    [0, r_bins[0]) followed by the shells [r_bins[i], r_bins[i + 1]).
    Particles at or beyond r_bins[-1] are ignored. The density of a bin is
    the total mass of the particles in it divided by the bin volume.

    r_bins must be a non-empty, monotonically increasing 1-D array of radii.

    Returns an array of len(r_bins) densities. If ret_error is True, returns
    the tuple (density, error) where error = density / sqrt(N) is the Poisson
    uncertainty for the N particles in each bin. Bins with zero volume get a
    density of 0 and bins without particles get an error of 0.

    Raises ValueError if particles does not have exactly 4 columns or if
    r_bins is not a non-empty 1-D array.
    """
    particles = np.asarray(particles)
    if particles.ndim != 2 or particles.shape[1] != 4:
        raise ValueError("Particles array must have 4 columns: x, y, z, m")

    r_bins = np.asarray(r_bins, dtype=float)
    if r_bins.ndim != 1 or r_bins.size == 0:
        raise ValueError("r_bins must be a non-empty 1-D array of bin edges")
    n_bins = r_bins.size

    m = particles[:, 3]
    r = np.linalg.norm(particles[:, :3], axis=1)

    # digitize returns 0 for r < r_bins[0], k for r_bins[k - 1] <= r < r_bins[k]
    # and n_bins for r >= r_bins[-1]; the last group lies outside every bin
    # and is dropped by the slice below.
    bin_index = np.digitize(r, r_bins)
    mass = np.bincount(bin_index, weights=m, minlength=n_bins + 1)[:n_bins]
    counts = np.bincount(bin_index, minlength=n_bins + 1)[:n_bins]

    # shell volumes, with the first volume taken wrt 0 (the inner sphere)
    volume = 4 / 3 * np.pi * (r_bins[1:]**3 - r_bins[:-1]**3)
    volume = np.insert(volume, 0, 4 / 3 * np.pi * r_bins[0]**3)

    # a zero-volume bin (e.g. r_bins[0] == 0) cannot hold mass: report 0
    # instead of dividing by zero
    density = np.divide(mass, volume, out=np.zeros(n_bins), where=volume > 0)
    if not ret_error:
        return density

    # Poisson error: relative uncertainty 1 / sqrt(N) per bin
    error = np.divide(density, np.sqrt(counts), out=np.zeros(n_bins), where=counts > 0)
    return density, error



def r_distribution(particles: np.ndarray):
    """
    Returns the distance to the origin of every particle.
    Only the first three columns (x, y, z) of the particles array are used;
    any further columns are ignored.

    Raises ValueError if the array has fewer than 3 columns.
    """
    n_columns = particles.shape[1]
    if n_columns >= 3:
        positions = particles[:, :3]
        # Euclidean norm of each row, i.e. one radius per particle
        return np.linalg.norm(positions, axis=1)

    raise ValueError("Particles array must have at least 3 columns: x, y, z")



def remove_outliers(particles: np.ndarray, std_threshold: float = 3):
    """
    Drops the particles whose distance to the origin is far from the mean.
    Only the first three columns (x, y, z) are used to compute the distance.

    A particle is kept when its radius deviates from the mean radius by
    strictly less than std_threshold standard deviations of the radii.
    Returns the kept rows of the particles array (all columns).

    Raises ValueError if the array has fewer than 3 columns.
    """
    n_columns = particles.shape[1]
    if n_columns < 3:
        raise ValueError("Particles array must have at least 3 columns: x, y, z")

    radii = np.linalg.norm(particles[:, :3], axis=1)

    # distance of each radius from the mean vs. the allowed spread
    deviation = np.abs(radii - radii.mean())
    cutoff = std_threshold * radii.std()

    keep = deviation < cutoff
    return particles[keep]



def mean_interparticle_distance(particles: np.ndarray):
    """
    Estimates the mean interparticle distance inside the half mass radius.
    Uses the first three columns (x, y, z) for the positions; the half mass
    radius is obtained from half_mass_radius(particles).

    The estimate is the cube root of the inverse number density of the
    particles strictly inside the half mass radius.

    Raises ValueError if the array has fewer than 3 columns.
    """
    if particles.shape[1] < 3:
        raise ValueError("Particles array must have at least 3 columns: x, y, z")

    r_hm = half_mass_radius(particles)
    radii = np.linalg.norm(particles[:, :3], axis=1)

    enclosed_count = np.sum(radii < r_hm)
    logger.debug(f"Number of particles within half mass radius: {enclosed_count} of {particles.shape[0]}")

    # number density of the sphere bounded by the half mass radius
    sphere_volume = 4/3 * np.pi * r_hm**3
    number_density = enclosed_count / sphere_volume

    # the volume available per particle is 1 / n, so the typical spacing is
    # its cube root
    # TODO: check if this is correct
    return (1 / number_density)**(1/3)



def half_mass_radius(particles: np.ndarray):
    """
    Computes the half mass radius of a set of particles.
    Assumes that the particles array has the following columns: x, y, z, m ...

    The particles are sorted by distance to the origin and the returned
    radius is that of the particle whose cumulative (enclosed) mass is
    closest to half of the total mass.

    Raises ValueError if the array has fewer than 4 columns or contains no
    particles.
    """
    # the mass is read from column 3, so x, y, z alone are not enough
    if particles.ndim != 2 or particles.shape[1] < 4:
        raise ValueError("Particles array must have at least 4 columns: x, y, z, m")
    if particles.shape[0] == 0:
        raise ValueError("Particles array must contain at least one particle")

    # even though in the simple example, all the masses are the same, we will consider the general case
    half_mass = np.sum(particles[:, 3]) / 2

    # sort the particles by distance
    r = np.linalg.norm(particles[:, :3], axis=1)
    indices = np.argsort(r)
    r = r[indices]
    masses_cumsum = np.cumsum(particles[indices, 3])

    # particle whose enclosed mass is closest to half of the total mass
    i = np.argmin(np.abs(masses_cumsum - half_mass))
    r_hm = r[i]
    logger.debug("Half mass radius: %s for %sth particle of %s", r_hm, i, particles.shape[0])

    return r_hm



def relaxation_timescale(particles: np.ndarray, G:float) -> float:
    """
    Computes the relaxation timescale of a set of particles using the velocity at the half mass radius.
    Assumes that the particles array has the following columns: x, y, z, m ...

    With r_half the half mass radius, m_half half of the total mass and N the
    number of particles strictly inside r_half, the crossing time is
    t_c = r_half / sqrt(G * m_half / r_half) and the returned value is
    t_c * N / (10 * ln(N)).

    G is the gravitational constant in the unit system of the particles.

    Raises ValueError if the array has fewer than 4 columns, or if fewer than
    2 particles lie inside the half mass radius (ln(N) would be zero or
    undefined).
    """
    if particles.ndim != 2 or particles.shape[1] < 4:
        raise ValueError("Particles array must have at least 4 columns: x, y, z, m")

    m_half = np.sum(particles[:, 3]) / 2 # enclosed mass at half mass radius
    r_half = half_mass_radius(particles)
    n_half = np.sum(np.linalg.norm(particles[:, :3], axis=1) < r_half) # number of enclosed particles
    if n_half < 2:
        raise ValueError("At least 2 particles inside the half mass radius are required")

    v_c = np.sqrt(G * m_half / r_half)

    # the crossing time for the half mass system is
    t_c = r_half / v_c
    logger.debug(f"Crossing time for half mass system: {t_c}")

    # the relaxation timescale is t_c * N/(10 * log(N)); numpy's natural
    # logarithm is np.log (np.ln does not exist)
    t_rel = t_c * n_half / (10 * np.log(n_half))

    return t_rel