import numpy as np
import scipy.integrate as spi

import logging
# module-level logger named after this module; used by ode_setup and to_particles_3d
logger = logging.getLogger(__name__)


def ode_setup(particles: np.ndarray, force_function: callable) -> tuple[np.ndarray, callable]:
    """
    Rewrites the second-order equations of motion of the particles as a first-order ODE
    system y' = f(y, t) on a flat state vector, so it can be handed to a generic integrator.

    The state vector y is the particles array flattened row by row: for every particle the
    7 values x, y, z, vx, vy, vz, m in that order (to_particles undoes the flattening).

    Args:
        particles: array of shape (n, 7) with columns x, y, z, vx, vy, vz, m.
        force_function: called with an (n, 4) array whose columns are x, y, z, m; must
            return the (n, 3) array of forces acting on each particle.

    Returns:
        - the Y0 array corresponding to the initial conditions (x0, v0 and m of every
          particle), flattened to 1D; non-floating input is converted to float64
        - the function f(y, t) that computes the right hand side dy/dt of the ODE.
          Note the argument order: it matches scipy.integrate.odeint (default tfirst=False)
          and runge_kutta_4; for scipy.integrate.solve_ivp wrap it as `lambda t, y: f(y, t)`.

    Raises:
        ValueError: if particles is not a 2D array with 7 columns.
    """
    particles = np.asarray(particles)
    if particles.ndim != 2 or particles.shape[1] != 7:
        raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m")

    # f writes accelerations into a copy of the state that keeps the state's dtype, so an
    # integer initial condition would silently truncate them; work in floating point instead
    if not np.issubdtype(particles.dtype, np.inexact):
        particles = particles.astype(np.float64)

    # for the integrators we need to flatten array which contains 7 columns for now
    # we don't really care how we reshape as long as we unflatten consistently (to_particles)
    particles = particles.flatten()
    logger.debug(f"Reshaped 7 columns into {particles.shape=}")

    def f(y, t):
        """
        Computes the right hand side dy/dt of the ODE system at state y.

        t is not used (the forces only depend on positions and masses) but is part of the
        f(y, t) signature the integrators call.
        """
        # explicit copy with shape (n, 7), columns x, y, z, vx, vy, vz, m
        # (y itself must stay intact since integrators make multiple function calls)
        p = to_particles(y)

        forces = force_function(p[:, [0, 1, 2, -1]])

        # divide each row of forces by the corresponding mass ([:, None] forces broadcasting);
        # a zero mass yields inf/nan accelerations
        masses = p[:, -1]
        a = forces / masses[:, None]

        # d(position)/dt = velocity, d(velocity)/dt = acceleration
        p[:, 0:3] = p[:, 3:6]
        p[:, 3:6] = a
        # the masses are constant, so their time derivative is zero
        # (leaving m in this column would make the integrator solve dm/dt = m)
        p[:, -1] = 0

        # flatten back into the 1D layout the integrators work with
        # (plain reshape: the copy= keyword only exists in NumPy >= 2.1)
        return p.reshape(-1)

    return particles, f


def to_particles(y: np.ndarray) -> np.ndarray:
    """
    Converts the flat state array y into a 2D particles array, always as a copy.

    The new shape is (n, 7) where n is the number of particles.
    The columns are x, y, z, vx, vy, vz, m
    The caller may modify the result freely without affecting y.

    Raises:
        ValueError: if the size of y is not a multiple of 7.
    """
    y = np.asarray(y)
    if y.size % 7 != 0:
        raise ValueError("The array y should be inflatable to 7 columns")

    # reshape(..., copy=True) requires NumPy >= 2.1; an explicit .copy() works on every
    # version and guarantees the result does not share memory with y
    return y.reshape((-1, 7)).copy()


def to_particles_3d(y: np.ndarray) -> np.ndarray:
    """
    Converts the 2D solution array, holding one flattened state vector per timestep (one
    per row), into a 3D array of shape (n_steps, n_particles, 7): for every timestep the
    (n_particles, 7) particles array with columns x, y, z, vx, vy, vz, m.

    The result is produced with reshape, so it may be a view sharing memory with y.

    Raises:
        ValueError: if y is not 2D or its row length is not a multiple of 7.
    """
    y = np.asarray(y)
    if y.ndim != 2 or y.shape[1] % 7 != 0:
        raise ValueError("The array y should be 2D with a row length inflatable to 7 columns")

    n_steps = y.shape[0]
    n_particles = y.shape[1] // 7
    y = y.reshape((n_steps, n_particles, 7))

    logger.info(f"Unflattened array into {y.shape=}")
    return y


def runge_kutta_4(y: np.ndarray, t: float, f: callable, dt: float):
    """
    Advances the state y from time t to time t + dt with one classical
    4th-order Runge-Kutta step and returns the new state.

    f is the right hand side of the ODE, called as f(y, t) and returning dy/dt.
    """
    t_mid = t + dt/2
    t_end = t + dt

    # four slope samples: start, two at the midpoint, end
    d_start = f(y, t)
    d_mid_a = f(y + d_start/2 * dt, t_mid)
    d_mid_b = f(y + d_mid_a/2 * dt, t_mid)
    d_end = f(y + d_mid_b * dt, t_end)

    # Simpson-like weighting 1:2:2:1 of the sampled slopes
    weighted_slope = d_start + 2*d_mid_a + 2*d_mid_b + d_end
    return y + weighted_slope/6 * dt