mesh solver finally with sensible results
This commit is contained in:
parent
65b893b41a
commit
9d54e9743e
483
nbody/copy.ipynb
483
nbody/copy.ipynb
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -12,3 +12,7 @@ from .forces_mesh import *
|
||||
|
||||
# Helpers for solving the IVP and having time evolution
|
||||
from .integrate import *
|
||||
|
||||
|
||||
# Fully mesh-based solver
|
||||
from .solver_mesh import *
|
||||
|
@ -75,6 +75,7 @@ def analytical_forces(particles: np.ndarray):
|
||||
The force on a particle at radius r is simply the force exerted by a point mass with the enclosed mass.
|
||||
Assumes that the particles array has the following columns: x, y, z, m.
|
||||
"""
|
||||
|
||||
n = particles.shape[0]
|
||||
forces = np.zeros((n, 3))
|
||||
|
||||
|
@ -8,50 +8,74 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
#### Version 1 - keeping the derivative of phi
|
||||
|
||||
'''def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
|
||||
def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
|
||||
"""
|
||||
Computes the gravitational forces between a set of particles using a mesh.
|
||||
Computes the gravitational force acting on a set of particles using a mesh-based approach.
|
||||
Assumes that the particles array has the following columns: x, y, z, m.
|
||||
"""
|
||||
if particles.shape[1] != 4:
|
||||
raise ValueError("Particles array must have 4 columns: x, y, z, m")
|
||||
|
||||
mesh, axis = to_mesh(particles, n_grid, mapping)
|
||||
# show_mesh_information(mesh, "Initial mesh")
|
||||
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
|
||||
|
||||
grad_phi = mesh_poisson(mesh, G)
|
||||
# show_mesh_information(mesh, "Mesh potential")
|
||||
mesh, axis = to_mesh(particles, n_grid, mapping)
|
||||
spacing = np.abs(axis[1] - axis[0])
|
||||
logger.debug(f"Using mesh spacing: {spacing}")
|
||||
|
||||
# we want a density mesh:
|
||||
cell_volume = spacing**3
|
||||
rho = mesh / cell_volume
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
show_mesh_information(mesh, "Density mesh")
|
||||
|
||||
# compute the potential and its gradient
|
||||
phi_grad = mesh_poisson(rho, G, spacing)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
|
||||
show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
|
||||
|
||||
# compute the particle forces from the mesh potential
|
||||
forces = np.zeros_like(particles[:, :3])
|
||||
for i, p in enumerate(particles):
|
||||
ijk = np.digitize(p, axis) - 1
|
||||
forces[i] = -grad_phi[ijk[0], ijk[1], ijk[2]] * p[3]
|
||||
logger.debug(f"Particle {p} maps to cell {ijk}")
|
||||
# this gives 4 entries since p[3] the mass is digitized as well -> this is meaningless and we discard it
|
||||
# logger.debug(f"Particle {p} maps to cell {ijk}")
|
||||
forces[i] = - p[3] * phi_grad[..., ijk[0], ijk[1], ijk[2]]
|
||||
|
||||
return forces
|
||||
|
||||
|
||||
def mesh_poisson(mesh: np.ndarray, G: float) -> np.ndarray:
|
||||
def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
|
||||
"""
|
||||
'Solves' the poisson equation for the mesh using the FFT.
|
||||
Returns the gradient of the potential since this is required for the force computation.
|
||||
Solves the poisson equation for the mesh using the FFT.
|
||||
Returns the derivative of the potential - grad phi
|
||||
"""
|
||||
rho_hat = fft.fftn(mesh)
|
||||
# the laplacian in fourier space takes the form of a multiplication
|
||||
k = np.fft.fftfreq(mesh.shape[0])
|
||||
k = fft.fftfreq(mesh.shape[0], spacing)
|
||||
# shift the zero frequency to the center
|
||||
k = np.fft.fftshift(k)
|
||||
# TODO: probably need to take the actual mesh bounds into account
|
||||
k = fft.fftshift(k)
|
||||
|
||||
kx, ky, kz = np.meshgrid(k, k, k)
|
||||
k_vec = np.array([kx, ky, kz])
|
||||
logger.debug(f"Got k_square with: {k_vec.shape}, {np.max(k_vec)} {np.min(k_vec)}")
|
||||
k_vec = np.stack([kx, ky, kz], axis=0)
|
||||
k_sr = kx**2 + ky**2 + kz**2
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
|
||||
logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
|
||||
show_mesh_information(np.abs(k_sr), "k_square")
|
||||
|
||||
grad_phi_hat = 4 * np.pi * G * rho_hat / (1j * k_vec * 2 * np.pi)
|
||||
k_sr[k_sr == 0] = np.inf
|
||||
k_inv = k_vec / k_sr # allows for element-wise division
|
||||
|
||||
# the inverse fourier transform gives the potential (or its gradient)
|
||||
logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}")
|
||||
grad_phi_hat = - 4 * np.pi * G * rho_hat * k_inv * 1j
|
||||
# nabla^2 phi => -i * k * nabla phi = 4 pi G rho => nabla phi = - i * rho * k / k^2
|
||||
# todo: check minus
|
||||
grad_phi = np.real(fft.ifftn(grad_phi_hat))
|
||||
return grad_phi
|
||||
'''
|
||||
|
||||
|
||||
#### Version 2 - only storing the scalar potential
|
||||
def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
|
||||
@ -65,24 +89,23 @@ def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callab
|
||||
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
|
||||
|
||||
mesh, axis = to_mesh(particles, n_grid, mapping)
|
||||
spacing = axis[1] - axis[0]
|
||||
spacing = np.abs(axis[1] - axis[0])
|
||||
logger.debug(f"Using mesh spacing: {spacing}")
|
||||
|
||||
# we want a density mesh:
|
||||
cell_volume = spacing**3
|
||||
rho = mesh / cell_volume
|
||||
|
||||
if logger.level >= logging.DEBUG:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
show_mesh_information(mesh, "Density mesh")
|
||||
|
||||
# compute the potential and its gradient
|
||||
phi = mesh_poisson_v2(rho, G, spacing)
|
||||
logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
|
||||
phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
|
||||
if logger.level >= logging.DEBUG:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
|
||||
show_mesh_information(phi, "Potential mesh")
|
||||
show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
|
||||
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
|
||||
|
||||
# compute the particle forces from the mesh potential
|
||||
forces = np.zeros_like(particles[:, :3])
|
||||
@ -105,11 +128,11 @@ def mesh_poisson_v2(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
|
||||
rho_hat = fft.fftn(mesh)
|
||||
k = fft.fftfreq(mesh.shape[0], spacing)
|
||||
# shift the zero frequency to the center
|
||||
k = np.fft.fftshift(k)
|
||||
k = fft.fftshift(k)
|
||||
|
||||
kx, ky, kz = np.meshgrid(k, k, k)
|
||||
k_sr = kx**2 + ky**2 + kz**2
|
||||
if logger.level >= logging.DEBUG:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
|
||||
logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
|
||||
show_mesh_information(np.abs(k_sr), "k_square")
|
||||
@ -129,10 +152,11 @@ def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.n
|
||||
"""
|
||||
Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid.
|
||||
Assumes that the particles array has the following columns: x, y, z, ..., m.
|
||||
Uses the mass of the particles and a smoothing function to detemine the contribution to each cell.
|
||||
Uses the mapping function to detemine the contribution to each cell.
|
||||
"""
|
||||
if particles.shape[1] < 4:
|
||||
raise ValueError("Particles array must have at least 4 columns: x, y, z, m")
|
||||
|
||||
# axis provide an easy way to map the particles to the mesh
|
||||
max_pos = np.max(particles[:, :3])
|
||||
axis = np.linspace(-max_pos, max_pos, n_grid)
|
||||
@ -202,7 +226,8 @@ def mesh_plot_3d(mesh: np.ndarray, name: str):
|
||||
fig = plt.figure()
|
||||
fig.suptitle(f"{name} - {mesh.shape}")
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis')
|
||||
sc = ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis')
|
||||
plt.colorbar(sc, ax=ax, label='Density')
|
||||
plt.show()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user