import logging
from collections.abc import Callable

import numpy as np
import scipy.integrate as spi

logger = logging.getLogger(__name__)


def ode_setup(
    particles: np.ndarray,
    force_function: Callable[[np.ndarray], np.ndarray],
) -> tuple[np.ndarray, Callable[..., np.ndarray]]:
    """
    Rewrite the gravitational N-body problem as a flat first-order ODE system.

    Assumes that the particles array has the following columns:
    x, y, z, vx, vy, vz, m.

    Args:
        particles: 2D array with one row per particle and 7 columns.
        force_function: called with an (n, 4) array whose columns are
            x, y, z, m; its result is divided row-wise by the masses and
            stacked below the (n, 3) velocities, so it must be an (n, 3)
            array of forces.

    Returns:
        - the y0 array with the initial conditions: all positions flattened
          (3*n values) followed by all velocities flattened (3*n values)
        - the function that computes the right hand side of the ODE with
          function signature f(t, y)

    Raises:
        ValueError: if particles is not a 2D array with 7 columns.
    """
    # Checking ndim first keeps a 1D input from failing with an IndexError
    # on shape[1] instead of the intended ValueError.
    if particles.ndim != 2 or particles.shape[1] != 7:
        raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m")

    n = particles.shape[0]

    # for scipy integrators we need to flatten the n 3D positions and
    # n 3D velocities into a single 1D state vector
    y0 = np.empty(6 * n)
    y0[: 3 * n] = particles[:, :3].ravel()
    y0[3 * n :] = particles[:, 3:6].ravel()

    # the masses don't change, so we can define them once
    masses = particles[:, 6]

    logger.debug(
        "Reshaped %s to y0 with %s and masses with %s",
        particles.shape,
        y0.shape,
        masses.shape,
    )

    def f(t, y):
        """
        Computes the right hand side of the ODE system, d/dt (x, v) = (v, a).

        The documented call is f(t, y). Earlier versions took (y, t); a
        positional call in that legacy order is detected and swapped so
        existing callers keep working.
        """
        # Legacy order: the first argument is the state array and the second
        # one the scalar time.
        if np.ndim(t) > 0 and np.ndim(y) == 0:
            t, y = y, t

        y = np.asarray(y)
        # lazy %-formatting: this function is evaluated at every solver step
        logger.debug("y with shape %s", y.shape)

        # unstack the flat state into (n, 3) positions and (n, 3) velocities
        state = y.reshape((2 * n, 3))
        x = state[:n, ...]
        v = state[n:, ...]
        logger.debug(
            "Unstacked y into x with shape %s and v with shape %s",
            x.shape,
            v.shape,
        )

        # compute the forces from the positions with the masses appended
        x_with_m = np.empty((n, 4))
        x_with_m[:, :3] = x
        x_with_m[:, 3] = masses
        forces = force_function(x_with_m)

        # the [:, None] is to force broadcasting in order to divide each
        # row of forces by the corresponding mass
        a = forces / masses[:, None]

        # stack velocities over accelerations and flatten into a 1D array
        return np.vstack((v, a)).ravel()

    return y0, f


def to_particles(y: np.ndarray) -> np.ndarray:
    """
    Converts the 1D array y into a 2D array with the shape (n, 6) where n is
    the number of particles.

    y holds the n flattened positions followed by the n flattened velocities.
    The columns of the result are x, y, z, vx, vy, vz.
    """
    n = y.size // 6
    state = y.reshape((2 * n, 3))
    positions = state[:n, ...]
    velocities = state[n:, ...]
    return np.hstack((positions, velocities))