mesh solver finally with sensible results
This commit is contained in:
parent
65b893b41a
commit
9d54e9743e
483
nbody/copy.ipynb
483
nbody/copy.ipynb
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -12,3 +12,7 @@ from .forces_mesh import *
|
|||||||
|
|
||||||
# Helpers for solving the IVP and having time evolution
|
# Helpers for solving the IVP and having time evolution
|
||||||
from .integrate import *
|
from .integrate import *
|
||||||
|
|
||||||
|
|
||||||
|
# Fully mesh-based solver
|
||||||
|
from .solver_mesh import *
|
||||||
|
@ -75,6 +75,7 @@ def analytical_forces(particles: np.ndarray):
|
|||||||
The force on a particle at radius r is simply the force exerted by a point mass with the enclosed mass.
|
The force on a particle at radius r is simply the force exerted by a point mass with the enclosed mass.
|
||||||
Assumes that the particles array has the following columns: x, y, z, m.
|
Assumes that the particles array has the following columns: x, y, z, m.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n = particles.shape[0]
|
n = particles.shape[0]
|
||||||
forces = np.zeros((n, 3))
|
forces = np.zeros((n, 3))
|
||||||
|
|
||||||
|
@ -8,50 +8,74 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
#### Version 1 - keeping the derivative of phi
|
#### Version 1 - keeping the derivative of phi
|
||||||
|
|
||||||
'''def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
|
def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Computes the gravitational forces between a set of particles using a mesh.
|
Computes the gravitational force acting on a set of particles using a mesh-based approach.
|
||||||
Assumes that the particles array has the following columns: x, y, z, m.
|
Assumes that the particles array has the following columns: x, y, z, m.
|
||||||
"""
|
"""
|
||||||
if particles.shape[1] != 4:
|
if particles.shape[1] != 4:
|
||||||
raise ValueError("Particles array must have 4 columns: x, y, z, m")
|
raise ValueError("Particles array must have 4 columns: x, y, z, m")
|
||||||
|
|
||||||
mesh, axis = to_mesh(particles, n_grid, mapping)
|
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
|
||||||
# show_mesh_information(mesh, "Initial mesh")
|
|
||||||
|
|
||||||
grad_phi = mesh_poisson(mesh, G)
|
mesh, axis = to_mesh(particles, n_grid, mapping)
|
||||||
# show_mesh_information(mesh, "Mesh potential")
|
spacing = np.abs(axis[1] - axis[0])
|
||||||
|
logger.debug(f"Using mesh spacing: {spacing}")
|
||||||
|
|
||||||
|
# we want a density mesh:
|
||||||
|
cell_volume = spacing**3
|
||||||
|
rho = mesh / cell_volume
|
||||||
|
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
show_mesh_information(mesh, "Density mesh")
|
||||||
|
|
||||||
|
# compute the potential and its gradient
|
||||||
|
phi_grad = mesh_poisson(rho, G, spacing)
|
||||||
|
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
|
||||||
|
show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
|
||||||
|
|
||||||
# compute the particle forces from the mesh potential
|
# compute the particle forces from the mesh potential
|
||||||
forces = np.zeros_like(particles[:, :3])
|
forces = np.zeros_like(particles[:, :3])
|
||||||
for i, p in enumerate(particles):
|
for i, p in enumerate(particles):
|
||||||
ijk = np.digitize(p, axis) - 1
|
ijk = np.digitize(p, axis) - 1
|
||||||
forces[i] = -grad_phi[ijk[0], ijk[1], ijk[2]] * p[3]
|
logger.debug(f"Particle {p} maps to cell {ijk}")
|
||||||
|
# this gives 4 entries since p[3] the mass is digitized as well -> this is meaningless and we discard it
|
||||||
|
# logger.debug(f"Particle {p} maps to cell {ijk}")
|
||||||
|
forces[i] = - p[3] * phi_grad[..., ijk[0], ijk[1], ijk[2]]
|
||||||
|
|
||||||
return forces
|
return forces
|
||||||
|
|
||||||
|
|
||||||
def mesh_poisson(mesh: np.ndarray, G: float) -> np.ndarray:
|
def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
'Solves' the poisson equation for the mesh using the FFT.
|
Solves the poisson equation for the mesh using the FFT.
|
||||||
Returns the gradient of the potential since this is required for the force computation.
|
Returns the derivative of the potential - grad phi
|
||||||
"""
|
"""
|
||||||
rho_hat = fft.fftn(mesh)
|
rho_hat = fft.fftn(mesh)
|
||||||
# the laplacian in fourier space takes the form of a multiplication
|
k = fft.fftfreq(mesh.shape[0], spacing)
|
||||||
k = np.fft.fftfreq(mesh.shape[0])
|
|
||||||
# shift the zero frequency to the center
|
# shift the zero frequency to the center
|
||||||
k = np.fft.fftshift(k)
|
k = fft.fftshift(k)
|
||||||
# TODO: probably need to take the actual mesh bounds into account
|
|
||||||
kx, ky, kz = np.meshgrid(k, k, k)
|
kx, ky, kz = np.meshgrid(k, k, k)
|
||||||
k_vec = np.array([kx, ky, kz])
|
k_vec = np.stack([kx, ky, kz], axis=0)
|
||||||
logger.debug(f"Got k_square with: {k_vec.shape}, {np.max(k_vec)} {np.min(k_vec)}")
|
k_sr = kx**2 + ky**2 + kz**2
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
|
||||||
|
logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
|
||||||
|
show_mesh_information(np.abs(k_sr), "k_square")
|
||||||
|
|
||||||
grad_phi_hat = 4 * np.pi * G * rho_hat / (1j * k_vec * 2 * np.pi)
|
k_sr[k_sr == 0] = np.inf
|
||||||
|
k_inv = k_vec / k_sr # allows for element-wise division
|
||||||
|
|
||||||
# the inverse fourier transform gives the potential (or its gradient)
|
logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}")
|
||||||
|
grad_phi_hat = - 4 * np.pi * G * rho_hat * k_inv * 1j
|
||||||
|
# nabla^2 phi => -i * k * nabla phi = 4 pi G rho => nabla phi = - i * rho * k / k^2
|
||||||
|
# todo: check minus
|
||||||
grad_phi = np.real(fft.ifftn(grad_phi_hat))
|
grad_phi = np.real(fft.ifftn(grad_phi_hat))
|
||||||
return grad_phi
|
return grad_phi
|
||||||
'''
|
|
||||||
|
|
||||||
#### Version 2 - only storing the scalar potential
|
#### Version 2 - only storing the scalar potential
|
||||||
def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
|
def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
|
||||||
@ -65,24 +89,23 @@ def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callab
|
|||||||
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
|
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
|
||||||
|
|
||||||
mesh, axis = to_mesh(particles, n_grid, mapping)
|
mesh, axis = to_mesh(particles, n_grid, mapping)
|
||||||
spacing = axis[1] - axis[0]
|
spacing = np.abs(axis[1] - axis[0])
|
||||||
logger.debug(f"Using mesh spacing: {spacing}")
|
logger.debug(f"Using mesh spacing: {spacing}")
|
||||||
|
|
||||||
# we want a density mesh:
|
# we want a density mesh:
|
||||||
cell_volume = spacing**3
|
cell_volume = spacing**3
|
||||||
rho = mesh / cell_volume
|
rho = mesh / cell_volume
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
if logger.level >= logging.DEBUG:
|
|
||||||
show_mesh_information(mesh, "Density mesh")
|
show_mesh_information(mesh, "Density mesh")
|
||||||
|
|
||||||
# compute the potential and its gradient
|
# compute the potential and its gradient
|
||||||
phi = mesh_poisson_v2(rho, G, spacing)
|
phi = mesh_poisson_v2(rho, G, spacing)
|
||||||
logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
|
logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
|
||||||
phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
|
phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
|
||||||
if logger.level >= logging.DEBUG:
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
|
||||||
show_mesh_information(phi, "Potential mesh")
|
show_mesh_information(phi, "Potential mesh")
|
||||||
show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
|
show_mesh_information(phi_grad[0], "Potential gradient (x-direction)")
|
||||||
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
|
|
||||||
|
|
||||||
# compute the particle forces from the mesh potential
|
# compute the particle forces from the mesh potential
|
||||||
forces = np.zeros_like(particles[:, :3])
|
forces = np.zeros_like(particles[:, :3])
|
||||||
@ -105,11 +128,11 @@ def mesh_poisson_v2(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray:
|
|||||||
rho_hat = fft.fftn(mesh)
|
rho_hat = fft.fftn(mesh)
|
||||||
k = fft.fftfreq(mesh.shape[0], spacing)
|
k = fft.fftfreq(mesh.shape[0], spacing)
|
||||||
# shift the zero frequency to the center
|
# shift the zero frequency to the center
|
||||||
k = np.fft.fftshift(k)
|
k = fft.fftshift(k)
|
||||||
|
|
||||||
kx, ky, kz = np.meshgrid(k, k, k)
|
kx, ky, kz = np.meshgrid(k, k, k)
|
||||||
k_sr = kx**2 + ky**2 + kz**2
|
k_sr = kx**2 + ky**2 + kz**2
|
||||||
if logger.level >= logging.DEBUG:
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
|
logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}")
|
||||||
logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
|
logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}")
|
||||||
show_mesh_information(np.abs(k_sr), "k_square")
|
show_mesh_information(np.abs(k_sr), "k_square")
|
||||||
@ -129,10 +152,11 @@ def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.n
|
|||||||
"""
|
"""
|
||||||
Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid.
|
Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid.
|
||||||
Assumes that the particles array has the following columns: x, y, z, ..., m.
|
Assumes that the particles array has the following columns: x, y, z, ..., m.
|
||||||
Uses the mass of the particles and a smoothing function to detemine the contribution to each cell.
|
Uses the mapping function to detemine the contribution to each cell.
|
||||||
"""
|
"""
|
||||||
if particles.shape[1] < 4:
|
if particles.shape[1] < 4:
|
||||||
raise ValueError("Particles array must have at least 4 columns: x, y, z, m")
|
raise ValueError("Particles array must have at least 4 columns: x, y, z, m")
|
||||||
|
|
||||||
# axis provide an easy way to map the particles to the mesh
|
# axis provide an easy way to map the particles to the mesh
|
||||||
max_pos = np.max(particles[:, :3])
|
max_pos = np.max(particles[:, :3])
|
||||||
axis = np.linspace(-max_pos, max_pos, n_grid)
|
axis = np.linspace(-max_pos, max_pos, n_grid)
|
||||||
@ -202,7 +226,8 @@ def mesh_plot_3d(mesh: np.ndarray, name: str):
|
|||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
fig.suptitle(f"{name} - {mesh.shape}")
|
fig.suptitle(f"{name} - {mesh.shape}")
|
||||||
ax = fig.add_subplot(111, projection='3d')
|
ax = fig.add_subplot(111, projection='3d')
|
||||||
ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis')
|
sc = ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis')
|
||||||
|
plt.colorbar(sc, ax=ax, label='Density')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user