248 lines
9.0 KiB
Python

## Implementation of a mesh based force solver
import numpy as np
import matplotlib.pyplot as plt
from scipy import fft
import logging
# Module-level logger; callers configure its level/handlers through the logging module.
logger = logging.getLogger(name=__name__)
#### Version 1 - keeping the derivative of phi
def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
    """
    Computes the gravitational force acting on a set of particles using a mesh-based approach.
    Assumes that the particles array has the following columns: x, y, z, m.

    The masses are deposited on an n_grid^3 mesh via `to_mesh(particles, n_grid, mapping)`,
    converted to a density, and `mesh_poisson` returns grad(phi) on the mesh. Each particle
    then receives F = -m * grad(phi) taken from the cell it falls into (np.digitize convention).

    Parameters
    ----------
    particles : (N, 4) array with columns x, y, z, m.
    G : gravitational constant.
    n_grid : number of mesh points per axis.
    mapping : callable(particle, axis) -> (cell indices, weights); may be a functools.partial.

    Returns
    -------
    (N, 3) array of forces. Floating-point input keeps its dtype; integer input is promoted
    to float64 so the forces are not truncated.

    Raises
    ------
    ValueError
        If particles is not a 2-D array with exactly 4 columns.
    """
    if particles.ndim != 2 or particles.shape[1] != 4:
        raise ValueError("Particles array must have 4 columns: x, y, z, m")
    # functools.partial objects have no __name__; fall back to their repr instead of crashing
    mapping_name = getattr(mapping, "__name__", repr(mapping))
    logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping_name}, {n_grid=}]")
    mesh, axis = to_mesh(particles, n_grid, mapping)
    spacing = np.abs(axis[1] - axis[0])
    logger.debug(f"Using mesh spacing: {spacing}")
    # we want a density mesh:
    cell_volume = spacing**3
    rho = mesh / cell_volume
    if logger.isEnabledFor(logging.DEBUG):
        show_mesh_information(mesh, "Density mesh")
    # compute the gradient of the potential, indexed as phi_grad[component, i, j, k]
    phi_grad = mesh_poisson(rho, G, spacing)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
        show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
    # cell lookup for all particles at once; clip so out-of-grid particles use the nearest
    # boundary cell instead of wrapping around through a negative index
    cells = np.clip(np.digitize(particles[:, :3], axis) - 1, 0, len(axis) - 1)  # (N, 3)
    if logger.isEnabledFor(logging.DEBUG):
        for p, cell in zip(particles, cells):
            logger.debug("Particle %s maps to cell %s", p, cell)
    grad_at_particles = phi_grad[:, cells[:, 0], cells[:, 1], cells[:, 2]]  # (3, N)
    forces = -particles[:, 3:4] * grad_at_particles.T
    out_dtype = particles.dtype if np.issubdtype(particles.dtype, np.floating) else np.float64
    return forces.astype(out_dtype, copy=False)
def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
    """
    Solves the poisson equation for the mesh using the FFT.
    Returns the derivative of the potential - grad phi

    Solves nabla^2 phi = 4 pi G rho on a periodic box and differentiates spectrally:
        -|k|^2 phi_hat = 4 pi G rho_hat,   grad -> i k
        =>  grad_phi_hat = -4 pi G i k rho_hat / |k|^2

    Parameters
    ----------
    mesh : 3-D density array rho[i, j, k] with axis 0 = x, 1 = y, 2 = z (axis lengths may differ).
    G : gravitational constant.
    spacing : grid spacing, assumed equal along all axes.

    Returns
    -------
    Array of shape (3, *mesh.shape); component c is d(phi)/d(axis c).
    The k = 0 mode is dropped, so a uniform density offset produces no force.
    """
    rho_hat = fft.fftn(mesh)
    # Angular wavenumbers in the FFT's native (unshifted) order so they line up with rho_hat;
    # fftfreq returns cycles per unit length, hence the factor 2 pi.
    ks = [2 * np.pi * fft.fftfreq(n, d=spacing) for n in mesh.shape]
    # 'ij' indexing keeps kx varying along mesh axis 0 (x), matching mesh[i, j, k]
    kx, ky, kz = np.meshgrid(*ks, indexing="ij")
    k_vec = np.stack([kx, ky, kz], axis=0)
    k_sr = kx**2 + ky**2 + kz**2
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
        logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
        show_mesh_information(np.abs(k_sr), "k_square")
    # k = 0: dividing by inf zeroes the mean-density mode instead of dividing by zero
    k_sr[k_sr == 0] = np.inf
    k_inv = k_vec / k_sr  # element-wise k / |k|^2, broadcast over the component axis
    logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}")
    grad_phi_hat = -4 * np.pi * G * rho_hat * k_inv * 1j
    # inverse transform over the spatial axes only; axis 0 holds the x/y/z components
    grad_phi = np.real(fft.ifftn(grad_phi_hat, axes=(1, 2, 3)))
    return grad_phi
#### Version 2 - only storing the scalar potential
def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
    """
    Computes the gravitational force acting on a set of particles using a mesh-based approach.
    Assumes that the particles array has the following columns: x, y, z, m.

    Unlike `mesh_forces`, the mesh solver (`mesh_poisson_v2`) returns the scalar potential phi,
    and its gradient is taken with finite differences (np.gradient). Each particle receives
    F = -m * grad(phi) from the cell it falls into (np.digitize convention).

    Parameters
    ----------
    particles : (N, 4) array with columns x, y, z, m.
    G : gravitational constant.
    n_grid : number of mesh points per axis.
    mapping : callable(particle, axis) -> (cell indices, weights); may be a functools.partial.

    Returns
    -------
    (N, 3) array of forces. Floating-point input keeps its dtype; integer input is promoted
    to float64 so the forces are not truncated.

    Raises
    ------
    ValueError
        If particles is not a 2-D array with exactly 4 columns.
    """
    if particles.ndim != 2 or particles.shape[1] != 4:
        raise ValueError("Particles array must have 4 columns: x, y, z, m")
    # functools.partial objects have no __name__; fall back to their repr instead of crashing
    mapping_name = getattr(mapping, "__name__", repr(mapping))
    logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping_name}, {n_grid=}]")
    mesh, axis = to_mesh(particles, n_grid, mapping)
    spacing = np.abs(axis[1] - axis[0])
    logger.debug(f"Using mesh spacing: {spacing}")
    # we want a density mesh:
    cell_volume = spacing**3
    rho = mesh / cell_volume
    if logger.isEnabledFor(logging.DEBUG):
        show_mesh_information(mesh, "Density mesh")
    # compute the potential and its finite-difference gradient, phi_grad[component, i, j, k]
    phi = mesh_poisson_v2(rho, G, spacing)
    phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
    if logger.isEnabledFor(logging.DEBUG):
        # np.max over the full mesh is only worth computing when the message is emitted
        logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
        logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
        show_mesh_information(phi, "Potential mesh")
        show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
    # cell lookup for all particles at once; clip so out-of-grid particles use the nearest
    # boundary cell instead of wrapping around through a negative index
    cells = np.clip(np.digitize(particles[:, :3], axis) - 1, 0, len(axis) - 1)  # (N, 3)
    grad_at_particles = phi_grad[:, cells[:, 0], cells[:, 1], cells[:, 2]]  # (3, N)
    forces = -particles[:, 3:4] * grad_at_particles.T
    out_dtype = particles.dtype if np.issubdtype(particles.dtype, np.floating) else np.float64
    return forces.astype(out_dtype, copy=False)
def mesh_poisson_v2(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
    """
    Solves the poisson equation for the mesh using the FFT.
    Returns the scalar potential.

    Solves nabla^2 phi = 4 pi G rho on a periodic box:
        -|k|^2 phi_hat = 4 pi G rho_hat  =>  phi_hat = -4 pi G rho_hat / |k|^2
    (the minus sign comes from (i k)^2 = -|k|^2).

    Parameters
    ----------
    mesh : 3-D density array rho[i, j, k] (axis lengths may differ).
    G : gravitational constant.
    spacing : grid spacing, assumed equal along all axes.

    Returns
    -------
    Real array with the same shape as `mesh`. The k = 0 mode is dropped, so phi has zero mean.
    """
    rho_hat = fft.fftn(mesh)
    # Angular wavenumbers in the FFT's native (unshifted) order so they line up with rho_hat;
    # fftfreq returns cycles per unit length, hence the factor 2 pi.
    ks = [2 * np.pi * fft.fftfreq(n, d=spacing) for n in mesh.shape]
    kx, ky, kz = np.meshgrid(*ks, indexing="ij")
    k_sr = kx**2 + ky**2 + kz**2
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
        logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
        show_mesh_information(np.abs(k_sr), "k_square")
    # k = 0: dividing by inf zeroes the mean-density mode instead of dividing by zero
    k_sr[k_sr == 0] = np.inf
    logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_sr.shape=}")
    phi_hat = -4 * np.pi * G * rho_hat / k_sr
    # ifftn includes the 1/N normalisation matching fftn, so no extra factors are needed
    phi = np.real(fft.ifftn(phi_hat))
    return phi
#### Helper functions for star mapping
def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.ndarray, np.ndarray]:
    """
    Maps a list of particles to a mesh of size n_grid x n_grid x n_grid.
    Assumes that the particles array has the following columns: x, y, z, ..., m.
    Uses the mapping function to determine the contribution to each cell.

    The shared grid axis is symmetric about the origin and spans [-R, R], where R is the
    largest absolute particle coordinate, so every particle lies inside the grid.

    Parameters
    ----------
    particles : (N, >=4) array; columns 0-2 are positions, the last column is the mass.
    n_grid : number of grid points per axis.
    mapping : callable(particle, axis) -> (iterable of cell indices, iterable of weights).

    Returns
    -------
    mesh : (n_grid, n_grid, n_grid) array holding the deposited mass (weight * m) per cell.
    axis : (n_grid,) increasing grid coordinates, shared by all three dimensions.

    Raises
    ------
    ValueError
        If particles has fewer than 4 columns, is empty, or all coordinates are zero
        (the grid would have zero extent and zero spacing).
    """
    if particles.ndim != 2 or particles.shape[1] < 4:
        raise ValueError("Particles array must have at least 4 columns: x, y, z, m")
    if particles.shape[0] == 0:
        raise ValueError("Cannot build a mesh from an empty particles array")
    # Use the largest *absolute* coordinate: np.max alone misses particles at very negative
    # positions, which np.digitize would then map to index -1 (silently the last cell).
    max_pos = np.max(np.abs(particles[:, :3]))
    if max_pos == 0:
        raise ValueError("Cannot build a mesh: all particle coordinates are zero")
    # axis provides an easy way to map the particles to the mesh
    axis = np.linspace(-max_pos, max_pos, n_grid)
    mesh = np.zeros((n_grid, n_grid, n_grid), dtype=axis.dtype)
    for p in particles:
        m = p[-1]
        # spread the particle onto cells through the shape function, taking into account the mass
        ijks, weights = mapping(p, axis)
        for ijk, weight in zip(ijks, weights):
            mesh[ijk[0], ijk[1], ijk[2]] += weight * m
    return mesh, axis
def particle_to_cells_nn(particle, axis):
    """
    Nearest-cell mapping: assign the whole particle to the single cell containing it.

    Every entry of `particle` is passed through np.digitize, so the returned index array has
    one entry per particle column (callers use the first three).

    Returns
    -------
    A one-element list holding the digitized index array, and the matching weight list [1].
    """
    return [np.digitize(particle, axis) - 1], [1]
# Integer offset vectors, listed as pairs: the second row of each pair negates the x and y
# components of the first while keeping z.
# NOTE(review): these are NOT the eight corners of a cube (that would be every sign
# combination of (+-1, +-1, +-1)); the first four rows lie in or on the z = 0 plane.
bbox = np.array(
    [
        [1, 0, 0], [-1, 0, 0],
        [1, 1, 0], [-1, -1, 0],
        [1, 1, 1], [-1, -1, 1],
        [1, 1, -1], [-1, -1, -1],
    ]
)
def particle_to_cells_cic(particle, axis, width):
    """
    Cloud-in-cell mapping: spread a particle over the cells overlapped by a cubic cloud.

    The cloud is the cube particle[:3] +/- width along every axis. Cells follow the
    np.digitize(x, axis) - 1 convention, i.e. cell i spans [axis[i], axis[i + 1]).
    Each cell's weight is its overlap volume with the cloud, normalised so the weights sum
    to 1 (mass leaving the grid is redistributed onto the overlapping cells). With
    width = spacing / 2 this is the classic CIC scheme (at most 2 cells per axis).

    Parameters
    ----------
    particle : array-like whose first three entries are x, y, z (extra entries, e.g. mass, ignored).
    axis : 1-D increasing array of grid coordinates (as produced by to_mesh).
    width : half side length of the cloud; width <= 0 falls back to nearest-cell assignment.

    Returns
    -------
    ijks : list of (i, j, k) index arrays.
    weights : np.ndarray of weights, one per entry of ijks, summing to 1.

    Use with to_mesh via functools.partial(particle_to_cells_cic, width=...).
    """
    axis = np.asarray(axis, dtype=float)
    pos = np.asarray(particle, dtype=float)[:3]
    n_cells = len(axis) - 1
    if width <= 0 or n_cells < 1:
        # degenerate cloud or grid: behave like nearest-cell assignment
        return [np.digitize(pos, axis) - 1], np.ones(1)

    # per dimension: indices of overlapped cells and the 1-D overlap length with each
    per_dim = []
    for x in pos:
        lo, hi = x - width, x + width
        overlap = np.clip(np.minimum(hi, axis[1:]) - np.maximum(lo, axis[:-1]), 0.0, None)
        if overlap.sum() == 0:
            # cloud lies completely outside the grid along this axis: snap to the nearest edge cell
            overlap = np.zeros(n_cells)
            overlap[0 if hi <= axis[0] else n_cells - 1] = 1.0
        idx = np.nonzero(overlap)[0]
        per_dim.append((idx, overlap[idx]))

    # the 3-D overlap volume factorises into the product of the per-axis overlaps
    ijks = []
    weights = []
    for i, wi in zip(*per_dim[0]):
        for j, wj in zip(*per_dim[1]):
            for k, wk in zip(*per_dim[2]):
                ijks.append(np.array([i, j, k]))
                weights.append(wi * wj * wk)
    weights = np.array(weights)
    weights /= np.sum(weights)
    return ijks, weights
#### Helper functions for mesh plotting
def show_mesh_information(mesh: np.ndarray, name: str):
    """
    Log summary statistics of `mesh` at INFO level, then display it with
    mesh_plot_3d and mesh_plot_2d (in that order).
    """
    logger.info(f"Mesh information for {name}")
    logger.info(f"Total mapped mass: {np.sum(mesh):.0f}")
    for label, reducer in (("Max", np.max), ("Min", np.min), ("Mean", np.mean)):
        logger.info(f"{label} cell value: {reducer(mesh)}")
    for plot in (mesh_plot_3d, mesh_plot_2d):
        plot(mesh, name)
def mesh_plot_3d(mesh: np.ndarray, name: str):
    """Scatter every non-zero cell of a 3-D mesh at its index position, coloured by value, and show it."""
    occupied = np.where(mesh)  # index arrays of the non-zero cells, computed once
    values = mesh[occupied]
    fig = plt.figure()
    fig.suptitle(f"{name} - {mesh.shape}")
    ax = fig.add_subplot(111, projection='3d')
    points = ax.scatter(*occupied, c=values, cmap='viridis')
    plt.colorbar(points, ax=ax, label='Density')
    plt.show()
def mesh_plot_2d(mesh: np.ndarray, name: str, only_z: bool = False):
    """
    Show 2-D projections of a 3-D mesh obtained by summing along one axis.

    With only_z=True a single image of the sum over axis 2 is drawn; otherwise three
    side-by-side panels show the sums over axes 0, 1 and 2.
    """
    fig = plt.figure()
    fig.suptitle(f"{name} - {mesh.shape}")
    if not only_z:
        panels = fig.subplots(1, 3)
        titles = ("Flattened in x", "Flattened in y", "Flattened in z")
        for summed_axis, (panel, title) in enumerate(zip(panels, titles)):
            panel.imshow(np.sum(mesh, axis=summed_axis), origin='lower')
            panel.set_title(title)
    else:
        plt.imshow(np.sum(mesh, axis=2), cmap='viridis', origin='lower')
    plt.show()