import logging
from collections.abc import Callable

import numpy as np
import scipy.integrate as spi

logger = logging.getLogger(__name__)


def ode_setup(
    particles: np.ndarray,
    force_function: Callable[[np.ndarray], np.ndarray],
) -> tuple[np.ndarray, Callable[[np.ndarray, float], np.ndarray]]:
    """Set up the first-order ODE system for gravitationally interacting particles.

    The second-order equations of motion x'' = F / m are rewritten as a
    first-order system in the state (x, v, m). The state is flattened to 1-D
    because the scipy integrators operate on flat state vectors.

    Args:
        particles: Array of shape (n, 7) with columns x, y, z, vx, vy, vz, m.
        force_function: Called with an (n, 4) array holding the columns
            x, y, z, m. Its result is divided row-wise by the masses and
            stored in three columns, so it must be broadcastable to (n, 3)
            (presumably one force vector per particle).

    Returns:
        A tuple ``(y0, f)``:
        - ``y0`` is the flattened initial state. It is a fresh copy in the
          column-major layout that ``to_particles`` undoes.
        - ``f(y, t)`` returns the time derivative of a flattened state.
          Note the argument order ``(y, t)``, which is the default
          convention of ``scipy.integrate.odeint``. For ``solve_ivp``, wrap
          it as ``lambda t, y: f(y, t)``.

    Raises:
        ValueError: If ``particles`` is not a 2-D array with 7 columns.
    """
    particles = np.asarray(particles)
    if particles.ndim != 2 or particles.shape[1] != 7:
        raise ValueError(
            "Particles array must have 7 columns: x, y, z, vx, vy, vz, m"
        )

    # Flatten column-major (order="F") to mirror the order="F" unflattening
    # in to_particles(). order="A" would flatten a C-ordered input row-major
    # and scramble the round trip. flatten() always copies, so the
    # integrator never aliases the caller's array.
    y0 = particles.flatten(order="F")
    logger.debug("Reshaped 7 columns into particles.shape=%s", y0.shape)

    def f(y: np.ndarray, t: float) -> np.ndarray:
        """Return dy/dt for the flattened state ``y``.

        ``t`` is not used by the computation, but it is part of the
        integrator's calling convention. ``y`` itself is never modified.
        """
        # (n, 7) view of y with columns x, y, z, vx, vy, vz, m
        state = to_particles(y)
        # Fancy indexing copies, so force_function cannot alter y.
        forces = force_function(state[:, [0, 1, 2, -1]])
        masses = state[:, -1]
        # [:, None] broadcasts so each row of forces is divided by its own mass
        accelerations = forces / masses[:, None]

        # Build the derivative in a fresh array. The integrator owns y, and
        # writing into it would corrupt the solver's state.
        dydt = np.zeros_like(state, dtype=np.result_type(state, accelerations))
        dydt[:, 0:3] = state[:, 3:6]  # dx/dt = v
        dydt[:, 3:6] = accelerations  # dv/dt = F / m
        # The mass column stays 0 (dm/dt = 0), so masses remain constant.
        return dydt.flatten(order="F")

    return y0, f


def to_particles(y: np.ndarray) -> np.ndarray:
    """Unflatten the 1-D state ``y`` into an (n, 7) particle array.

    This is the inverse of the column-major flattening done by ``ode_setup``.
    The columns of the result are x, y, z, vx, vy, vz, m.

    The result is a view sharing memory with ``y`` whenever numpy can reshape
    without copying, e.g. for a contiguous 1-D ``y``. Otherwise numpy returns
    a copy.

    Raises:
        ValueError: If the size of ``y`` is not a multiple of 7.
    """
    y = np.asarray(y)
    if y.size % 7 != 0:
        raise ValueError("The array y should be inflatable to 7 columns")
    n = y.size // 7
    # order="F" must mirror the flattening in ode_setup(). copy=False is
    # avoided: it needs NumPy >= 2.1 and raises whenever a copy is required.
    particles = np.reshape(y, (n, 7), order="F")
    logger.debug("Unflattened array into y.shape=%s", particles.shape)
    return particles