69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
import numpy as np
|
|
import scipy.integrate as spi
|
|
|
|
import logging
|
|
# Module-level logger; the RHS closure returned by ode_setup logs at DEBUG
# level on every evaluation, so it uses lazy %-style arguments.
logger = logging.getLogger(__name__)


def ode_setup(particles: np.ndarray, force_function: callable) -> tuple[np.ndarray, callable]:
    """
    Rewrite Newton's equations of motion for the particles as a first-order ODE system.

    The second-order system x'' = F(x) / m is reduced to first order by
    stacking all positions and all velocities into one flat state vector

        y = [x_1, y_1, z_1, ..., x_n, y_n, z_n, vx_1, vy_1, vz_1, ..., vx_n, vy_n, vz_n]

    so that dy/dt = [v, a] can be handed to a scipy integrator.

    Args:
        particles: (n, 7) array with columns x, y, z, vx, vy, vz, m.
            Array-likes (e.g. nested lists) are converted with ``np.asarray``.
        force_function: called with an (n, 4) array whose columns are
            x, y, z, m; must return the (n, 3) array of net forces.

    Returns:
        - y0: flat float array of length 6*n holding the initial conditions
          (all positions first, then all velocities).
        - f: the right-hand side of the ODE with signature ``f(y, t)``.
          This matches ``scipy.integrate.odeint``'s default convention. ``t``
          is unused. For ``solve_ivp``, which calls ``fun(t, y)``, wrap it as
          ``lambda t, y: f(y, t)``.

    Raises:
        ValueError: if ``particles`` is not a 2-D array with 7 columns.
    """
    particles = np.asarray(particles)
    # Check ndim first so a 1-D input gets the documented ValueError instead
    # of an IndexError from shape[1].
    if particles.ndim != 2 or particles.shape[1] != 7:
        raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m")

    n = particles.shape[0]

    # scipy integrators work on a flat 1-D state: the n 3D positions followed
    # by the n 3D velocities.
    y0 = np.zeros(6 * n)
    y0[:3 * n] = particles[:, :3].flatten()
    y0[3 * n:] = particles[:, 3:6].flatten()

    # The masses don't change during integration, so extract them once and
    # capture them in the closure below.
    masses = particles[:, 6]
    logger.debug("Reshaped %s to y0 with %s and masses with %s", particles.shape, y0.shape, masses.shape)

    def f(y, t):
        """
        Return dy/dt = [v, a] for the flat state ``y``, in the same layout as ``y``.

        ``t`` is accepted for integrator compatibility but not used.
        """
        n = y.size // 6
        logger.debug("y with shape %s", y.shape)

        # (6n,) -> (2n, 3): the first n rows are positions, the last n are velocities.
        y = y.reshape((2 * n, 3))
        x = y[:n, ...]
        v = y[n:, ...]
        logger.debug("Unstacked y into x with shape %s and v with shape %s", x.shape, v.shape)

        # force_function expects positions with each particle's mass appended
        # as a fourth column.
        x_with_m = np.zeros((n, 4))
        x_with_m[:, :3] = x
        x_with_m[:, 3] = masses
        forces = force_function(x_with_m)

        # [:, None] broadcasts so each row of forces is divided by its own mass.
        # NOTE(review): a zero mass yields inf/nan accelerations; callers
        # presumably never pass massless particles.
        a = forces / masses[:, None]

        # Flatten back to the 1-D [v, a] layout the integrator expects.
        return np.vstack((v, a)).ravel()

    return y0, f
|
|
|
|
|
|
def to_particles(y: np.ndarray) -> np.ndarray:
    """
    Turn a flat state vector back into one row per particle.

    ``y`` holds all positions followed by all velocities
    (x_1, y_1, z_1, ..., x_n, y_n, z_n, vx_1, ..., vz_n). The result has
    shape (n, 6) with n = y.size // 6 and columns x, y, z, vx, vy, vz.
    """
    count = y.size // 6
    # Rows 0..count-1 are positions, rows count..2*count-1 are velocities.
    positions, velocities = np.split(y.reshape((2 * count, 3)), 2)
    return np.hstack((positions, velocities))