mesh solver finally with sensible results
This commit is contained in:
		| @@ -12,3 +12,7 @@ from .forces_mesh import * | ||||
|  | ||||
| # Helpers for solving the IVP and having time evolution | ||||
| from .integrate import * | ||||
|  | ||||
|  | ||||
| # Fully mesh-based solver | ||||
| from .solver_mesh import * | ||||
|   | ||||
| @@ -75,6 +75,7 @@ def analytical_forces(particles: np.ndarray): | ||||
|     The force on a particle at radius r is simply the force exerted by a point mass with the enclosed mass. | ||||
|     Assumes that the particles array has the following columns: x, y, z, m. | ||||
|     """ | ||||
|  | ||||
|     n = particles.shape[0] | ||||
|     forces = np.zeros((n, 3)) | ||||
|  | ||||
|   | ||||
| @@ -8,50 +8,74 @@ logger = logging.getLogger(__name__) | ||||
|  | ||||
| #### Version 1 - keeping the derivative of phi | ||||
|  | ||||
| '''def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray: | ||||
| def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray: | ||||
|     """ | ||||
|     Computes the gravitational forces between a set of particles using a mesh. | ||||
|     Assumes that the particles array has the following columns: x, y, z, m. | ||||
|     Computes the gravitational force acting on a set of particles using a mesh-based approach. | ||||
|     Assumes that the particles array has the following columns: x, y, z, m.  | ||||
|     """ | ||||
|     if particles.shape[1] != 4: | ||||
|         raise ValueError("Particles array must have 4 columns: x, y, z, m") | ||||
|  | ||||
|     mesh, axis = to_mesh(particles, n_grid, mapping) | ||||
|     # show_mesh_information(mesh, "Initial mesh") | ||||
|     logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]") | ||||
|  | ||||
|     grad_phi = mesh_poisson(mesh, G) | ||||
|     # show_mesh_information(mesh, "Mesh potential") | ||||
|     mesh, axis = to_mesh(particles, n_grid, mapping) | ||||
|     spacing = np.abs(axis[1] - axis[0]) | ||||
|     logger.debug(f"Using mesh spacing: {spacing}") | ||||
|  | ||||
|     # we want a density mesh: | ||||
|     cell_volume = spacing**3 | ||||
|     rho = mesh / cell_volume | ||||
|  | ||||
|     if logger.isEnabledFor(logging.DEBUG): | ||||
|         show_mesh_information(mesh, "Density mesh") | ||||
|  | ||||
|     # compute the potential and its gradient | ||||
|     phi_grad = mesh_poisson(rho, G, spacing) | ||||
|  | ||||
|     if logger.isEnabledFor(logging.DEBUG): | ||||
|         logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") | ||||
|         show_mesh_information(phi_grad[0], "Potential gradient (x-direction)") | ||||
|  | ||||
|     # compute the particle forces from the mesh potential | ||||
|     forces = np.zeros_like(particles[:, :3]) | ||||
|     for i, p in enumerate(particles): | ||||
|         ijk = np.digitize(p, axis) - 1 | ||||
|         forces[i] = -grad_phi[ijk[0], ijk[1], ijk[2]] * p[3] | ||||
|         logger.debug(f"Particle {p} maps to cell {ijk}") | ||||
|         # this gives 4 entries since p[3] the mass is digitized as well -> this is meaningless and we discard it | ||||
|         # logger.debug(f"Particle {p} maps to cell {ijk}") | ||||
|         forces[i] = - p[3] * phi_grad[..., ijk[0], ijk[1], ijk[2]] | ||||
|  | ||||
|     return forces | ||||
|  | ||||
|  | ||||
| def mesh_poisson(mesh: np.ndarray, G: float) -> np.ndarray: | ||||
| def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray: | ||||
|     """ | ||||
|     'Solves' the poisson equation for the mesh using the FFT. | ||||
|     Returns the gradient of the potential since this is required for the force computation. | ||||
|     Solves the poisson equation for the mesh using the FFT. | ||||
|     Returns the derivative of the potential - grad phi | ||||
|     """ | ||||
|     rho_hat = fft.fftn(mesh) | ||||
|     # the laplacian in fourier space takes the form of a multiplication | ||||
|     k = np.fft.fftfreq(mesh.shape[0]) | ||||
|     k = fft.fftfreq(mesh.shape[0], spacing) | ||||
|     # shift the zero frequency to the center | ||||
|     k = np.fft.fftshift(k) | ||||
|     # TODO: probably need to take the actual mesh bounds into account | ||||
|     k = fft.fftshift(k) | ||||
|  | ||||
|     kx, ky, kz = np.meshgrid(k, k, k) | ||||
|     k_vec = np.array([kx, ky, kz]) | ||||
|     logger.debug(f"Got k_square with: {k_vec.shape}, {np.max(k_vec)} {np.min(k_vec)}") | ||||
|     k_vec = np.stack([kx, ky, kz], axis=0) | ||||
|     k_sr = kx**2 + ky**2 + kz**2 | ||||
|     if logger.isEnabledFor(logging.DEBUG): | ||||
|         logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}") | ||||
|         logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}") | ||||
|         show_mesh_information(np.abs(k_sr), "k_square") | ||||
|  | ||||
|     grad_phi_hat = 4 * np.pi * G * rho_hat / (1j * k_vec * 2 * np.pi) | ||||
|     k_sr[k_sr == 0] = np.inf | ||||
|     k_inv = k_vec / k_sr # allows for element-wise division | ||||
|  | ||||
|     # the inverse fourier transform gives the potential (or its gradient) | ||||
|     logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}") | ||||
|     grad_phi_hat = - 4 * np.pi * G * rho_hat * k_inv * 1j | ||||
|     # nabla^2 phi => -i * k * nabla phi = 4 pi G rho => nabla phi = - i * rho * k / k^2 | ||||
|     # todo: check minus | ||||
|     grad_phi = np.real(fft.ifftn(grad_phi_hat)) | ||||
|     return grad_phi | ||||
| ''' | ||||
|  | ||||
|  | ||||
| #### Version 2 - only storing the scalar potential | ||||
| def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray: | ||||
| @@ -65,24 +89,23 @@ def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callab | ||||
|     logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]") | ||||
|  | ||||
|     mesh, axis = to_mesh(particles, n_grid, mapping) | ||||
|     spacing = axis[1] - axis[0] | ||||
|     spacing = np.abs(axis[1] - axis[0]) | ||||
|     logger.debug(f"Using mesh spacing: {spacing}") | ||||
|  | ||||
|     # we want a density mesh: | ||||
|     cell_volume = spacing**3 | ||||
|     rho = mesh / cell_volume | ||||
|  | ||||
|     if logger.level >= logging.DEBUG: | ||||
|     if logger.isEnabledFor(logging.DEBUG): | ||||
|         show_mesh_information(mesh, "Density mesh") | ||||
|  | ||||
|     # compute the potential and its gradient | ||||
|     phi = mesh_poisson_v2(rho, G, spacing) | ||||
|     logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}") | ||||
|     phi_grad = np.stack(np.gradient(phi, spacing), axis=0) | ||||
|     if logger.level >= logging.DEBUG: | ||||
|     if logger.isEnabledFor(logging.DEBUG): | ||||
|         logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") | ||||
|         show_mesh_information(phi, "Potential mesh") | ||||
|         show_mesh_information(phi_grad[0], "Potential gradient (x-direction)") | ||||
|         logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") | ||||
|  | ||||
|     # compute the particle forces from the mesh potential | ||||
|     forces = np.zeros_like(particles[:, :3]) | ||||
| @@ -105,11 +128,11 @@ def mesh_poisson_v2(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray: | ||||
|     rho_hat = fft.fftn(mesh) | ||||
|     k = fft.fftfreq(mesh.shape[0], spacing) | ||||
|     # shift the zero frequency to the center | ||||
|     k = np.fft.fftshift(k) | ||||
|     k = fft.fftshift(k) | ||||
|  | ||||
|     kx, ky, kz = np.meshgrid(k, k, k) | ||||
|     k_sr = kx**2 + ky**2 + kz**2 | ||||
|     if logger.level >= logging.DEBUG: | ||||
|     if logger.isEnabledFor(logging.DEBUG): | ||||
|         logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}") | ||||
|         logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}") | ||||
|         show_mesh_information(np.abs(k_sr), "k_square") | ||||
| @@ -129,10 +152,11 @@ def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.n | ||||
|     """ | ||||
|     Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid. | ||||
|     Assumes that the particles array has the following columns: x, y, z, ..., m. | ||||
|     Uses the mass of the particles and a smoothing function to detemine the contribution to each cell. | ||||
|     Uses the mapping function to detemine the contribution to each cell. | ||||
|     """ | ||||
|     if particles.shape[1] < 4: | ||||
|         raise ValueError("Particles array must have at least 4 columns: x, y, z, m") | ||||
|  | ||||
|     # axis provide an easy way to map the particles to the mesh | ||||
|     max_pos = np.max(particles[:, :3]) | ||||
|     axis = np.linspace(-max_pos, max_pos, n_grid) | ||||
| @@ -202,7 +226,8 @@ def mesh_plot_3d(mesh: np.ndarray, name: str): | ||||
|     fig = plt.figure() | ||||
|     fig.suptitle(f"{name} - {mesh.shape}") | ||||
|     ax = fig.add_subplot(111, projection='3d') | ||||
|     ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis') | ||||
|     sc = ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis') | ||||
|     plt.colorbar(sc, ax=ax, label='Density') | ||||
|     plt.show() | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user