mesh solver finally with sensible results

This commit is contained in:
Remy Moll 2025-01-22 22:05:29 +01:00
parent 65b893b41a
commit 9d54e9743e
6 changed files with 341 additions and 714 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,3 +12,7 @@ from .forces_mesh import *
# Helpers for solving the IVP and having time evolution
from .integrate import *
# Fully mesh-based solver
from .solver_mesh import *

View File

@ -75,6 +75,7 @@ def analytical_forces(particles: np.ndarray):
The force on a particle at radius r is simply the force exerted by a point mass with the enclosed mass.
Assumes that the particles array has the following columns: x, y, z, m.
"""
n = particles.shape[0]
forces = np.zeros((n, 3))

View File

@ -8,50 +8,74 @@ logger = logging.getLogger(__name__)
#### Version 1 - keeping the derivative of phi
'''def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
"""
Computes the gravitational forces between a set of particles using a mesh.
Computes the gravitational force acting on a set of particles using a mesh-based approach.
Assumes that the particles array has the following columns: x, y, z, m.
"""
if particles.shape[1] != 4:
raise ValueError("Particles array must have 4 columns: x, y, z, m")
mesh, axis = to_mesh(particles, n_grid, mapping)
# show_mesh_information(mesh, "Initial mesh")
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
grad_phi = mesh_poisson(mesh, G)
# show_mesh_information(mesh, "Mesh potential")
mesh, axis = to_mesh(particles, n_grid, mapping)
spacing = np.abs(axis[1] - axis[0])
logger.debug(f"Using mesh spacing: {spacing}")
# we want a density mesh:
cell_volume = spacing**3
rho = mesh / cell_volume
if logger.isEnabledFor(logging.DEBUG):
show_mesh_information(mesh, "Density mesh")
# compute the potential and its gradient
phi_grad = mesh_poisson(rho, G, spacing)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
# compute the particle forces from the mesh potential
forces = np.zeros_like(particles[:, :3])
for i, p in enumerate(particles):
ijk = np.digitize(p, axis) - 1
forces[i] = -grad_phi[ijk[0], ijk[1], ijk[2]] * p[3]
logger.debug(f"Particle {p} maps to cell {ijk}")
# this gives 4 entries since p[3] the mass is digitized as well -> this is meaningless and we discard it
# logger.debug(f"Particle {p} maps to cell {ijk}")
forces[i] = - p[3] * phi_grad[..., ijk[0], ijk[1], ijk[2]]
return forces
def mesh_poisson(mesh: np.ndarray, G: float) -> np.ndarray:
def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
"""
'Solves' the poisson equation for the mesh using the FFT.
Returns the gradient of the potential since this is required for the force computation.
Solves the poisson equation for the mesh using the FFT.
Returns the derivative of the potential - grad phi
"""
rho_hat = fft.fftn(mesh)
# the laplacian in fourier space takes the form of a multiplication
k = np.fft.fftfreq(mesh.shape[0])
k = fft.fftfreq(mesh.shape[0], spacing)
# shift the zero frequency to the center
k = np.fft.fftshift(k)
# TODO: probably need to take the actual mesh bounds into account
k = fft.fftshift(k)
kx, ky, kz = np.meshgrid(k, k, k)
k_vec = np.array([kx, ky, kz])
logger.debug(f"Got k_square with: {k_vec.shape}, {np.max(k_vec)} {np.min(k_vec)}")
k_vec = np.stack([kx, ky, kz], axis=0)
k_sr = kx**2 + ky**2 + kz**2
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
show_mesh_information(np.abs(k_sr), "k_square")
grad_phi_hat = 4 * np.pi * G * rho_hat / (1j * k_vec * 2 * np.pi)
k_sr[k_sr == 0] = np.inf
k_inv = k_vec / k_sr # allows for element-wise division
# the inverse fourier transform gives the potential (or its gradient)
logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}")
grad_phi_hat = - 4 * np.pi * G * rho_hat * k_inv * 1j
# nabla^2 phi => -i * k * nabla phi = 4 pi G rho => nabla phi = - i * rho * k / k^2
# todo: check minus
grad_phi = np.real(fft.ifftn(grad_phi_hat))
return grad_phi
'''
#### Version 2 - only storing the scalar potential
def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
@ -65,24 +89,23 @@ def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callab
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
mesh, axis = to_mesh(particles, n_grid, mapping)
spacing = axis[1] - axis[0]
spacing = np.abs(axis[1] - axis[0])
logger.debug(f"Using mesh spacing: {spacing}")
# we want a density mesh:
cell_volume = spacing**3
rho = mesh / cell_volume
if logger.level >= logging.DEBUG:
if logger.isEnabledFor(logging.DEBUG):
show_mesh_information(mesh, "Density mesh")
# compute the potential and its gradient
phi = mesh_poisson_v2(rho, G, spacing)
logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
if logger.level >= logging.DEBUG:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
show_mesh_information(phi, "Potential mesh")
show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
# compute the particle forces from the mesh potential
forces = np.zeros_like(particles[:, :3])
@ -105,11 +128,11 @@ def mesh_poisson_v2(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
rho_hat = fft.fftn(mesh)
k = fft.fftfreq(mesh.shape[0], spacing)
# shift the zero frequency to the center
k = np.fft.fftshift(k)
k = fft.fftshift(k)
kx, ky, kz = np.meshgrid(k, k, k)
k_sr = kx**2 + ky**2 + kz**2
if logger.level >= logging.DEBUG:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
show_mesh_information(np.abs(k_sr), "k_square")
@ -129,10 +152,11 @@ def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.n
"""
Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid.
Assumes that the particles array has the following columns: x, y, z, ..., m.
Uses the mass of the particles and a smoothing function to detemine the contribution to each cell.
Uses the mapping function to detemine the contribution to each cell.
"""
if particles.shape[1] < 4:
raise ValueError("Particles array must have at least 4 columns: x, y, z, m")
# axis provide an easy way to map the particles to the mesh
max_pos = np.max(particles[:, :3])
axis = np.linspace(-max_pos, max_pos, n_grid)
@ -202,7 +226,8 @@ def mesh_plot_3d(mesh: np.ndarray, name: str):
fig = plt.figure()
fig.suptitle(f"{name} - {mesh.shape}")
ax = fig.add_subplot(111, projection='3d')
ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis')
sc = ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis')
plt.colorbar(sc, ax=ax, label='Density')
plt.show()