finally accurate mesh forces

This commit is contained in:
Remy Moll 2025-02-02 21:51:56 +01:00
parent 37a2687ffe
commit da8a7d4574
10 changed files with 661 additions and 317 deletions

View File

@ -31,13 +31,13 @@
// Finally - The real content // Finally - The real content
= N-body forces and analytical solutions = N-body forces and analytical solutions
== Objective // == Objective
Implement naive N-body force computation and get an intuition of the challenges: // Implement naive N-body force computation and get an intuition of the challenges:
- accuracy // - accuracy
- computation time // - computation time
- stability // - stability
$==>$ still useful to compute basic quantities of the system, but too limited for large systems or the dynamical evolution of the system // $==>$ still useful to compute basic quantities of the system, but too limited for large systems or the dynamical evolution of the system
== Overview - the system == Overview - the system
@ -47,7 +47,7 @@ Get a feel for the particles and their distribution. [#link(<task1:plot_particle
#columns(2)[ #columns(2)[
#helpers.image_cell(t1, "plot_particle_distribution") #helpers.image_cell(t1, "plot_particle_distribution")
Note: for visibility the outer particles are not shown. // Note: for visibility the outer particles are not shown.
#colbreak() #colbreak()
The system at hand is characterized by: The system at hand is characterized by:
- $N ~ 10^4$ stars - $N ~ 10^4$ stars
@ -81,12 +81,12 @@ We compare the computed density with the analytical model provided by the _Hernq
#helpers.image_cell(t1, "plot_density_distribution") #helpers.image_cell(t1, "plot_density_distribution")
] ]
) )
// Note that by construction, the first shell contains no particles
// => the numerical density is zero there
// Having more bins means to have shells that are nearly empty
// => the error is large, NBINS = 30 is a good compromise
#block(
height: 1fr,
)
== Force computation == Force computation
// N Body and variations // N Body and variations
@ -110,13 +110,48 @@ We compare the computed density with the analytical model provided by the _Hernq
] ]
) )
// basic $N^2$ matches analytical solution without dropoff. but: noisy data from "bad" samples
// $N^2$ with softening matches analytical solution but has a dropoff. No noisy data.
// => softening $\approx 1 \varepsilon$ is a sweet spot since the dropoff is "late"
== Relaxation == Relaxation
Relaxation [#link(<task1:compute_relaxation_time>)[code]]:
// #helpers.code_cell(t1, "compute_relaxation_time") We express system relaxation in terms of the dynamical time of the system.
$
t_"relax" = overbrace(N / (8 log N), n_"relax") dot t_"crossing"
$
where the crossing time of the system can be estimated through the half-mass velocity $t_"crossing" = v(r_"hm")/r_"hm"$.
We find a relaxation of [#link(<task1:compute_relaxation_time>)[code]].
Discussion! // === Discussion
#grid(
columns: (1fr, 1fr),
inset: 0.5em,
block[
#image("relaxation.png")
],
block[
- Each star-star interaction contributes $delta v approx (2 G m )/b$
- Shifting by $epsilon$ *dampens* each contribution
- $=>$ relaxation time increases
]
)
// The estimate for $n_{relax}$ comes from the contribution of each star-star encounter to the velocity dispersion. This depends on the perpendicular force
// $\implies$ a bigger softening length leads to a smaller $\delta v$.
// Using $n_{relax} = \frac{v^2}{\delta v^2}$, and knowing that the value of $v^2$ is derived from the Virial theorem (i.e. unaffected by the softening length), we can see that $n_{relax}$ should increase with $\varepsilon$.
// === Effect
// - The relaxation time **increases** with increasing softening length
// - From the integration over all impact parameters $b$ even $b_{min}$ is chosen to be larger than $\varepsilon$ $\implies$ expect only a small effect on the relaxation time
// **In other words:**
// The softening dampens the change of velocity => time to relax is longer
@ -128,10 +163,13 @@ Discussion!
columns: 2 columns: 2
)[ )[
#helpers.image_cell(t2, "plot_particle_distribution") #helpers.image_cell(t2, "plot_particle_distribution")
$=>$ use $M_"sys" approx 10^4 M_"sol" + M_"BH"$
] ]
== Force computation == Force computation
#helpers.code_reference_cell(t2, "function_mesh_force") #helpers.code_reference_cell(t2, "function_mesh_force")
@ -156,9 +194,16 @@ Discussion!
- very large grids have issues with overdiscretization - very large grids have issues with overdiscretization
$==> 75 times 75 times 75$ as a good compromise $==> 75 times 75 times 75$ as a good compromise
// Some other comments:
// - see the artifacts because of the even grid numbers (hence the switch to 75)
// overdiscretization for large grids -> vertical spread even though r is constant
// this becomes even more apparent when looking at the data without noise - the artifacts remain
] ]
) )
#helpers.image_cell(t2, "plot_force_computation_time")
== Time integration == Time integration
=== Runge-Kutta === Runge-Kutta
#helpers.code_reference_cell(t2, "function_runge_kutta") #helpers.code_reference_cell(t2, "function_runge_kutta")

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -9,6 +9,7 @@ from .units import *
from .forces_basic import * from .forces_basic import *
from .forces_tree import * from .forces_tree import *
from .forces_mesh import * from .forces_mesh import *
from .forces_cache import *
# Helpers for solving the IVP and having time evolution # Helpers for solving the IVP and having time evolution
from .integrate import * from .integrate import *

View File

@ -3,7 +3,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def n_body_forces(particles: np.ndarray, G: float, softening: float = 0): def n_body_forces(particles: np.ndarray, G: float = 1, softening: float = 0):
""" """
Computes the gravitational forces between a set of particles. Computes the gravitational forces between a set of particles.
Assumes that the particles array has the following columns: x, y, z, m. Assumes that the particles array has the following columns: x, y, z, m.

View File

@ -0,0 +1,33 @@
from pathlib import Path
import numpy as np
import timeit
import logging
logger = logging.getLogger(__name__)
def cached_forces(cache_path: Path, particles: np.ndarray, force_function:callable, func_kwargs: dict):
"""
Tries to load the forces from a cache file. If that fails, computes the forces using the provided function.
"""
cache_path.mkdir(parents=True, exist_ok=True)
n_particles = particles.shape[0]
kwargs_str = "_".join([f"{k}_{v}" for k, v in func_kwargs.items()])
force_cache = cache_path / f"forces__{force_function.__name__}__n_{n_particles}__kwargs_{kwargs_str}.npy"
time_cache = cache_path / f"time__{force_function.__name__}__n_{n_particles}__kwargs_{kwargs_str}.npy"
if force_cache.exists() and time_cache.exists():
force = np.load(force_cache)
logger.info(f"Loaded forces from {force_cache}")
time = np.load(time_cache)
else:
force = force_function(particles, **func_kwargs)
np.save(force_cache, force)
time = 0
np.info(f"Timing {force_function.__name__} for {n_particles} particles")
time = timeit.timeit(lambda: force_function(particles, **func_kwargs), number=10)
np.save(time_cache, time)
return force, time

View File

@ -6,8 +6,7 @@ from scipy import fft
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
#### Version 1 - keeping the derivative of phi '''
def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray: def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
""" """
Computes the gravitational force acting on a set of particles using a mesh-based approach. Computes the gravitational force acting on a set of particles using a mesh-based approach.
@ -18,10 +17,11 @@ def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable)
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]") logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
mesh, axis = to_mesh(particles, n_grid, mapping) # in this case we create an adaptively sized mesh containing all particles
spacing = np.abs(axis[1] - axis[0]) max_pos = np.max(np.abs(particles[:, :3]))
logger.debug(f"Using mesh spacing: {spacing}") mesh, axis, spacing = create_mesh(-max_pos, max_pos, n_grid)
fill_mesh(particles, mesh, axis, mapping)
# we want a density mesh: # we want a density mesh:
cell_volume = spacing**3 cell_volume = spacing**3
rho = mesh / cell_volume rho = mesh / cell_volume
@ -54,9 +54,8 @@ def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
Returns the derivative of the potential - grad phi Returns the derivative of the potential - grad phi
""" """
rho_hat = fft.fftn(mesh) rho_hat = fft.fftn(mesh)
k = fft.fftfreq(mesh.shape[0], spacing) k = fft.fftfreq(mesh.shape[0], spacing) * (2 * np.pi)
# shift the zero frequency to the center # shift the zero frequency to the center
k = fft.fftshift(k)
kx, ky, kz = np.meshgrid(k, k, k) kx, ky, kz = np.meshgrid(k, k, k)
k_vec = np.stack([kx, ky, kz], axis=0) k_vec = np.stack([kx, ky, kz], axis=0)
@ -72,13 +71,14 @@ def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}") logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}")
grad_phi_hat = - 4 * np.pi * G * rho_hat * k_inv * 1j grad_phi_hat = - 4 * np.pi * G * rho_hat * k_inv * 1j
# nabla^2 phi => -i * k * nabla phi = 4 pi G rho => nabla phi = - i * rho * k / k^2 # nabla^2 phi => -i * k * nabla phi = 4 pi G rho => nabla phi = - i * rho * k / k^2
# todo: check minus # TODO: check minus
grad_phi = np.real(fft.ifftn(grad_phi_hat)) grad_phi = np.real(fft.ifftn(grad_phi_hat))
return grad_phi return grad_phi
'''
#### Version 2 - only storing the scalar potential
def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray: def mesh_forces(particles: np.ndarray, G: float = 1, n_grid: int = 50, mapping: callable = None) -> np.ndarray:
""" """
Computes the gravitational force acting on a set of particles using a mesh-based approach. Computes the gravitational force acting on a set of particles using a mesh-based approach.
Assumes that the particles array has the following columns: x, y, z, m. Assumes that the particles array has the following columns: x, y, z, m.
@ -88,126 +88,120 @@ def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callab
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]") logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
mesh, axis = to_mesh(particles, n_grid, mapping) # in this case we create an adaptively sized mesh containing all particles
spacing = np.abs(axis[1] - axis[0]) max_pos = np.max(np.abs(particles[:, :3]))
logger.debug(f"Using mesh spacing: {spacing}") mesh, axis, spacing = create_mesh(-max_pos, max_pos, n_grid)
fill_mesh(particles, mesh, axis, mapping)
# we want a density mesh: # we want a density mesh:
cell_volume = spacing**3 cell_volume = spacing**3
rho = mesh / cell_volume rho = mesh / cell_volume
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
show_mesh_information(mesh, "Density mesh") show_mesh_information(mesh, "Density mesh")
# compute the potential and its gradient # compute the potential and its gradient
phi = mesh_poisson_v2(rho, G, spacing) phi = mesh_poisson(rho, G, spacing)
logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
show_mesh_information(phi, "Potential mesh") show_mesh_information(phi, "Potential")
show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
# get the acceleration from finite differences of the potential
# a = - grad phi
ax, ay, az = np.gradient(phi, spacing)
a_vec = - np.stack([ax, ay, az], axis=0)
# compute the particle forces from the mesh potential # compute the particle forces from the mesh potential
forces = np.zeros_like(particles[:, :3]) forces = np.zeros_like(particles[:, :3])
for i, p in enumerate(particles): ijks = np.digitize(particles[:, :3], axis) - 1
ijk = np.digitize(p, axis) - 1
# this gives 4 entries since p[3] the mass is digitized as well -> this is meaningless and we discard it for i in range(particles.shape[0]):
# logger.debug(f"Particle {p} maps to cell {ijk}") m = particles[i, 3]
forces[i] = - p[3] * phi_grad[..., ijk[0], ijk[1], ijk[2]] idx = ijks[i]
# TODO remove factor of 10 # f = m * a
# TODO could also index phi_grad the other way around? forces[i] = m * a_vec[..., idx[0], idx[1], idx[2]]
return forces return forces
def mesh_poisson_v2(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray: def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
""" """
Solves the poisson equation for the mesh using the FFT. Solves the poisson equation for the mesh using the FFT.
Returns the scalar potential. Returns the the potential - grad
""" """
rho_hat = fft.fftn(mesh) rho_hat = fft.fftn(mesh)
k = fft.fftfreq(mesh.shape[0], spacing)
# shift the zero frequency to the center
k = fft.fftshift(k)
# we also need the wave numbers
spacing_3d = np.linalg.norm([spacing, spacing, spacing])
k = fft.fftfreq(mesh.shape[0], spacing) * (2 * np.pi)
# TODO: check if this is correct
# assuming the grid is cubic
kx, ky, kz = np.meshgrid(k, k, k) kx, ky, kz = np.meshgrid(k, k, k)
k_sr = kx**2 + ky**2 + kz**2 k_sr = kx**2 + ky**2 + kz**2
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}") logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}") logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
show_mesh_information(np.abs(k_sr), "k_square") show_mesh_information(np.abs(k_sr), "k_square")
# avoid division by zero
# TODO: review this
k_sr[k_sr == 0] = np.inf k_sr[k_sr == 0] = np.inf
logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_sr.shape=}") # k_inv = k_vec / k_sr # allows for element-wise division
# logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}")
phi_hat = - 4 * np.pi * G * rho_hat / k_sr phi_hat = - 4 * np.pi * G * rho_hat / k_sr
# - comes from i squared # nabla^2 phi becomes -i * k * nabla phi_hat = 4 pi G rho_hat
# TODO: 4pi stays since the backtransform removes the 1/2pi factor # => nabla phi = - i * rho * k / k^2
phi = np.real(fft.ifftn(phi_hat)) phi = np.real(fft.ifftn(phi_hat))
return phi return phi
#### Helper functions for star mapping #### Helper functions for star mapping
def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.ndarray, np.ndarray]: def create_mesh(min_pos: float, max_pos: float, n_grid: int) -> tuple[np.ndarray, np.ndarray, float]:
""" """
Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid. Creates an empty 3D mesh with the given dimensions.
Returns the mesh, the axis and the spacing between the cells.
"""
axis = np.linspace(min_pos, max_pos, n_grid)
mesh = np.zeros((n_grid, n_grid, n_grid))
spacing = np.diff(axis)[0]
logger.debug(f"Using mesh spacing: {spacing}")
return mesh, axis, spacing
def fill_mesh(particles: np.ndarray, mesh: np.ndarray, axis: np.ndarray, mapping: callable):
"""
Maps a list of particles to a the mesh (in place)
Assumes that the particles array has the following columns: x, y, z, ..., m. Assumes that the particles array has the following columns: x, y, z, ..., m.
Uses the mapping function to detemine the contribution to each cell. Uses the mapping function to detemine the contribution to each cell. The mapped density should be normalized to 1.
""" """
if particles.shape[1] < 4: if particles.shape[1] < 4:
raise ValueError("Particles array must have at least 4 columns: x, y, z, m") raise ValueError("Particles array must have at least 4 columns: x, y, z, ..., m")
# axis provide an easy way to map the particles to the mesh # each particle will have its particular contirbution (determined through a weight function, mapping)
max_pos = np.max(particles[:, :3]) for i in range(particles.shape[0]):
axis = np.linspace(-max_pos, max_pos, n_grid) p = particles[i]
mesh_grid = np.meshgrid(axis, axis, axis) mapping(mesh, p, axis) # this directly adds to the mesh
mesh = np.zeros_like(mesh_grid[0])
for p in particles:
m = p[-1]
# spread the star onto cells through the shape function, taking into account the mass
ijks, weights = mapping(p, axis)
for ijk, weight in zip(ijks, weights):
mesh[ijk[0], ijk[1], ijk[2]] += weight * m
return mesh, axis
def particle_to_cells_nn(particle, axis): def particle_mapping_nn(mesh_to_fill: np.ndarray, particle: np.ndarray, axis: np.ndarray):
# find the single cell that contains the particle # fills the mesh in place with the particle mass
ijk = np.digitize(particle, axis) - 1 ijk = np.digitize(particle, axis) - 1
# the weight is obviously 1 mesh_to_fill[ijk[0], ijk[1], ijk[2]] += particle[3]
return [ijk], [1]
bbox = np.array([
[1, 0, 0],
[-1, 0, 0],
[1, 1, 0],
[-1, -1, 0],
[1, 1, 1],
[-1, -1, 1],
[1, 1, -1],
[-1, -1, -1]
])
def particle_to_cells_cic(particle, axis, width): def particle_mapping_cic(mesh_to_fill: np.ndarray, particle: np.ndarray, axis: np.ndarray):
# create a virtual cell around the particle and check the intersections # fills the mesh in place with the particle mass
bounding_cell = particle + width * bbox ijk = np.digitize(particle, axis) - 1
spacing = axis[1] - axis[0]
# find all the cells that intersect with the virtual cell # generate a 3D map of all the distances to the particle
ijks = [] px, py, pz = np.meshgrid(axis, axis, axis, indexing='ij')
weights = [] dist = np.linalg.norm([px - particle[0], py - particle[1], pz - particle[2]], axis=0)
for b in bounding_cell:
# TODO: this is not the correct weight
w = np.linalg.norm(particle - b)
ijk = np.digitize(b, axis) - 1
ijks.append(ijk)
weights.append(w)
# ensure that the weights sum to 1 # the weights are the inverse of the distance, cut off at the cell size
weights = np.array(weights) weights = np.maximum(0, 1 - dist / spacing)
weights /= np.sum(weights) mesh_to_fill += particle[3] * weights
return ijks, weights

View File

@ -163,6 +163,11 @@ def particles_plot_3d(positions: np.ndarray, masses: np.ndarray, title: str = "P
fig = plt.figure() fig = plt.figure()
fig.suptitle(title) fig.suptitle(title)
ax = fig.add_subplot(111, projection='3d') ax = fig.add_subplot(111, projection='3d')
if np.all(masses == masses[0]):
sc = ax.scatter(x, y, z, c='blue')
cbar = plt.colorbar(sc, ax=ax, pad=0.1)
else:
sc = ax.scatter(x, y, z, cmap='coolwarm', c=masses) sc = ax.scatter(x, y, z, cmap='coolwarm', c=masses)
cbar = plt.colorbar(sc, ax=ax, pad=0.1) cbar = plt.colorbar(sc, ax=ax, pad=0.1)

View File

@ -1,6 +1,6 @@
## Implementation of a mesh based full solver with boundary conditions etc. ## Implementation of a mesh based full solver with boundary conditions etc.
import numpy as np import numpy as np
from . import mesh_forces from . import forces_mesh
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,30 +36,38 @@ def mesh_solver(
logger.debug(f"Using mesh spacing: {spacing}") logger.debug(f"Using mesh spacing: {spacing}")
# Check that the boundary condition is fullfilled # # Check that the boundary condition is fullfilled
if boundary == "periodic": # if boundary == "periodic":
raise NotImplementedError("Periodic boundary conditions are not implemented yet") # raise NotImplementedError("Periodic boundary conditions are not implemented yet")
elif boundary == "vanishing": # elif boundary == "vanishing":
# remove the particles that are outside the mesh # # remove the particles that are outside the mesh
outlier_mask = particles[:, :3] < bounds[0] | particles[:, :3] > bounds[1] # outlier_mask = np.logical_or(particles[:, :3] < bounds[0], particles[:, :3] > bounds[1])
# if np.any(outlier_mask):
# idx = np.any(outlier_mask, axis=1)
# logger.info(f"{idx.shape=}")
# logger.warning(f"Removing {np.sum(idx)} particles that left the mesh")
# # replace the particles by nan values
# particles[idx, :] = np.nan
# print(np.sum(np.isnan(particles)))
# else:
# raise ValueError(f"Unknown boundary condition: {boundary}")
if np.any(outlier_mask):
logger.warning(f"Removing {np.sum(outlier_mask)} particles that are outside the mesh")
particles = particles[~outlier_mask]
logger.debug(f"New particles shape: {particles.shape}")
else:
raise ValueError(f"Unknown boundary condition: {boundary}")
# fill the mesh
particles_to_mesh(particles, mesh, axis, mapping)
# we want a density mesh:
cell_volume = spacing**3
rho = mesh / cell_volume
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
mesh_forces.show_mesh_information(mesh, "Density mesh") forces_mesh.show_mesh_information(mesh, "Density mesh")
# compute the potential and its gradient # compute the potential and its gradient
phi_grad = mesh_forces.mesh_poisson(rho, G, spacing) phi_grad = forces_mesh.mesh_poisson(rho, G, spacing)
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
mesh_forces.show_mesh_information(phi_grad[0], "Potential gradient (x-direction)") forces_mesh.show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
# compute the particle forces from the mesh potential # compute the particle forces from the mesh potential
forces = np.zeros_like(particles[:, :3]) forces = np.zeros_like(particles[:, :3])
@ -87,3 +95,56 @@ def particles_to_mesh(particles: np.ndarray, mesh: np.ndarray, axis: np.ndarray,
ijks, weights = mapping(p, axis) ijks, weights = mapping(p, axis)
for ijk, weight in zip(ijks, weights): for ijk, weight in zip(ijks, weights):
mesh[ijk[0], ijk[1], ijk[2]] += weight * m mesh[ijk[0], ijk[1], ijk[2]] += weight * m
'''
#### Actually need to patch this
def ode_setup(particles: np.ndarray, force_function: callable) -> tuple[np.ndarray, callable]:
"""
Linearizes the ODE system for the particles interacting gravitationally.
Returns:
- the Y0 array corresponding to the initial conditions (x0 and v0)
- the function that computes the right hand side of the ODE with function signature f(t, y)
Assumes that the particles array has the following columns: x, y, z, vx, vy, vz, m.
"""
if particles.shape[1] != 7:
raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m")
# for the integrators we need to flatten array which contains 7 columns for now
# we don't really care how we reshape as long as we unflatten consistently
particles = particles.flatten()
logger.debug(f"Reshaped 7 columns into {particles.shape=}")
def f(y, t):
"""
Computes the right hand side of the ODE system.
The ODE system is linearized around the current positions and velocities.
"""
p = to_particles(y)
# this is explicitly a copy, which has shape (n, 7)
# columns x, y, z, vx, vy, vz, m
# (need to keep y intact since integrators make multiple function calls)
forces = force_function(p[:, [0, 1, 2, -1]])
# compute the accelerations
masses = p[:, -1]
a = forces / masses[:, None]
# the [:, None] is to force broadcasting in order to divide each row of forces by the corresponding mass
# the position columns become the velocities
# the velocity columns become the accelerations
p[:, 0:3] = p[:, 3:6]
p[:, 3:6] = a
# the masses remain unchanged
# p[:, -1] = p[:, -1]
# flatten the array again
# logger.debug(f"As particles: {y}")
p = p.reshape(-1, copy=False)
# logger.debug(f"As column: {y}")
return p
return particles, f
'''