68 lines
2.4 KiB
Python

import numpy as np
import scipy.integrate as spi
import logging
logger = logging.getLogger(__name__)
def ode_setup(
    particles: np.ndarray,
    force_function: callable,
    tfirst: bool = False,
) -> tuple[np.ndarray, callable]:
    """
    Rewrite the equations of motion of interacting particles as a first-order
    ODE system over a flat state vector, as required by ``scipy.integrate``.

    Parameters
    ----------
    particles:
        Array of shape (n, 7) with columns x, y, z, vx, vy, vz, m.
    force_function:
        Called with an (n, 4) array whose columns are x, y, z, m. Its result is
        divided row-wise by the masses and stored as the (n, 3) accelerations,
        so it must broadcast to shape (n, 3).
    tfirst:
        If False (default) the returned right-hand side has the
        ``scipy.integrate.odeint`` signature ``f(y, t)``; if True it has the
        ``scipy.integrate.solve_ivp`` signature ``f(t, y)``.

    Returns
    -------
    y0:
        The flattened initial state in column-major (Fortran) order, so that
        ``to_particles(y0)`` recovers ``particles``. It is a copy; the caller's
        array is never aliased.
    f:
        The right-hand side of the ODE system. The time argument is unused
        (the system is autonomous). The mass derivative is 0.

    Raises
    ------
    ValueError
        If ``particles`` is not a 2-D array with 7 columns.
    """
    particles = np.asarray(particles)
    if particles.ndim != 2 or particles.shape[1] != 7:
        raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m")
    # Flatten column-major so the layout matches the order='F' unflattening in
    # to_particles(); any other order scrambles the state for C-contiguous
    # input. flatten() always copies, so the solver never aliases `particles`.
    y0 = particles.flatten(order='F')
    logger.debug("Reshaped 7 columns into particles.shape=%s", y0.shape)

    def _derivative(y):
        """Return dy/dt for the flat state vector ``y``."""
        # May be a view of the solver's own state vector, so it is only read.
        state = to_particles(y)
        # force_function receives columns x, y, z, m
        forces = force_function(state[:, [0, 1, 2, -1]])
        masses = state[:, -1]
        # [:, None] broadcasts so each row of forces is divided by its own mass
        accelerations = forces / masses[:, None]
        # Build the derivative in a fresh array instead of overwriting `state`:
        # writing through the view would corrupt the integrator's state.
        dydt = np.zeros(state.shape, dtype=np.result_type(state, accelerations))
        dydt[:, 0:3] = state[:, 3:6]    # d(position)/dt = velocity
        dydt[:, 3:6] = accelerations    # d(velocity)/dt = acceleration
        # column 6 stays 0: masses are constant, so dm/dt = 0
        return dydt.flatten(order='F')

    def f(y, t):
        """Right-hand side with the ``odeint`` signature ``f(y, t)``."""
        return _derivative(y)

    def f_tfirst(t, y):
        """Right-hand side with the ``solve_ivp`` signature ``f(t, y)``."""
        return _derivative(y)

    return y0, (f_tfirst if tfirst else f)


def to_particles(y: np.ndarray) -> np.ndarray:
    """
    Reshape a flat state vector into an (n, 7) particle array.

    ``y`` is read in column-major (Fortran) order, matching the flattening in
    ``ode_setup``: the first n entries are the x column, the next n the y
    column, and so on. The columns are x, y, z, vx, vy, vz, m.

    A view of ``y`` is returned whenever NumPy can provide one (always the case
    for a contiguous 1-D array), so writes to the result write to ``y``;
    otherwise a copy is returned.

    Raises
    ------
    ValueError
        If the size of ``y`` is not a multiple of 7.
    """
    y = np.asarray(y)
    if y.size % 7 != 0:
        raise ValueError("The array y should be inflatable to 7 columns")
    n = y.size // 7
    # No `copy=False`: that keyword needs NumPy >= 2.1 and raises instead of
    # copying when a view is impossible.
    particles = y.reshape((n, 7), order='F')
    logger.debug("Unflattened array into y.shape=%s", particles.shape)
    return particles