## Implementation of a mesh based force solver import numpy as np import matplotlib.pyplot as plt from scipy import fft import logging logger = logging.getLogger(__name__) #### Version 1 - keeping the derivative of phi def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray: """ Computes the gravitational force acting on a set of particles using a mesh-based approach. Assumes that the particles array has the following columns: x, y, z, m. """ if particles.shape[1] != 4: raise ValueError("Particles array must have 4 columns: x, y, z, m") logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]") mesh, axis = to_mesh(particles, n_grid, mapping) spacing = np.abs(axis[1] - axis[0]) logger.debug(f"Using mesh spacing: {spacing}") # we want a density mesh: cell_volume = spacing**3 rho = mesh / cell_volume if logger.isEnabledFor(logging.DEBUG): show_mesh_information(mesh, "Density mesh") # compute the potential and its gradient phi_grad = mesh_poisson(rho, G, spacing) if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") show_mesh_information(phi_grad[0], "Potential gradient (x-direction)") # compute the particle forces from the mesh potential forces = np.zeros_like(particles[:, :3]) for i, p in enumerate(particles): ijk = np.digitize(p, axis) - 1 logger.debug(f"Particle {p} maps to cell {ijk}") # this gives 4 entries since p[3] the mass is digitized as well -> this is meaningless and we discard it # logger.debug(f"Particle {p} maps to cell {ijk}") forces[i] = - p[3] * phi_grad[..., ijk[0], ijk[1], ijk[2]] return forces def mesh_poisson(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray: """ Solves the poisson equation for the mesh using the FFT. Returns the derivative of the potential - grad phi """ rho_hat = fft.fftn(mesh) k = fft.fftfreq(mesh.shape[0], spacing) # shift the zero frequency to the center k = fft.fftshift(k) kx, ky, kz = np.meshgrid(k, k, k) k_vec = np.stack([kx, ky, kz], axis=0) k_sr = kx**2 + ky**2 + kz**2 if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}") logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}") show_mesh_information(np.abs(k_sr), "k_square") k_sr[k_sr == 0] = np.inf k_inv = k_vec / k_sr # allows for element-wise division logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_inv.shape=}") grad_phi_hat = - 4 * np.pi * G * rho_hat * k_inv * 1j # nabla^2 phi => -i * k * nabla phi = 4 pi G rho => nabla phi = - i * rho * k / k^2 # todo: check minus grad_phi = np.real(fft.ifftn(grad_phi_hat)) return grad_phi #### Version 2 - only storing the scalar potential def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray: """ Computes the gravitational force acting on a set of particles using a mesh-based approach. Assumes that the particles array has the following columns: x, y, z, m. """ if particles.shape[1] != 4: raise ValueError("Particles array must have 4 columns: x, y, z, m") logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]") mesh, axis = to_mesh(particles, n_grid, mapping) spacing = np.abs(axis[1] - axis[0]) logger.debug(f"Using mesh spacing: {spacing}") # we want a density mesh: cell_volume = spacing**3 rho = mesh / cell_volume if logger.isEnabledFor(logging.DEBUG): show_mesh_information(mesh, "Density mesh") # compute the potential and its gradient phi = mesh_poisson_v2(rho, G, spacing) logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}") phi_grad = np.stack(np.gradient(phi, spacing), axis=0) if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") show_mesh_information(phi, "Potential mesh") show_mesh_information(phi_grad[0], "Potential gradient (x-direction)") # compute the particle forces from the mesh potential forces = np.zeros_like(particles[:, :3]) for i, p in enumerate(particles): ijk = np.digitize(p, axis) - 1 # this gives 4 entries since p[3] the mass is digitized as well -> this is meaningless and we discard it # logger.debug(f"Particle {p} maps to cell {ijk}") forces[i] = - p[3] * phi_grad[..., ijk[0], ijk[1], ijk[2]] # TODO remove factor of 10 # TODO could also index phi_grad the other way around? return forces def mesh_poisson_v2(mesh: np.ndarray, G: float, spacing: float) -> np.ndarray: """ Solves the poisson equation for the mesh using the FFT. Returns the scalar potential. """ rho_hat = fft.fftn(mesh) k = fft.fftfreq(mesh.shape[0], spacing) # shift the zero frequency to the center k = fft.fftshift(k) kx, ky, kz = np.meshgrid(k, k, k) k_sr = kx**2 + ky**2 + kz**2 if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Got k_square with: {k_sr.shape}, {np.max(k_sr)} {np.min(k_sr)}") logger.debug(f"Count of ksquare zeros: {np.sum(k_sr == 0)}") show_mesh_information(np.abs(k_sr), "k_square") # avoid division by zero # TODO: review this k_sr[k_sr == 0] = np.inf logger.debug(f"Proceeding to poisson equation with {rho_hat.shape=}, {k_sr.shape=}") phi_hat = - 4 * np.pi * G * rho_hat / k_sr # - comes from i squared # TODO: 4pi stays since the backtransform removes the 1/2pi factor phi = np.real(fft.ifftn(phi_hat)) return phi #### Helper functions for star mapping def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.ndarray, np.ndarray]: """ Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid. Assumes that the particles array has the following columns: x, y, z, ..., m. Uses the mapping function to detemine the contribution to each cell. """ if particles.shape[1] < 4: raise ValueError("Particles array must have at least 4 columns: x, y, z, m") # axis provide an easy way to map the particles to the mesh max_pos = np.max(particles[:, :3]) axis = np.linspace(-max_pos, max_pos, n_grid) mesh_grid = np.meshgrid(axis, axis, axis) mesh = np.zeros_like(mesh_grid[0]) for p in particles: m = p[-1] # spread the star onto cells through the shape function, taking into account the mass ijks, weights = mapping(p, axis) for ijk, weight in zip(ijks, weights): mesh[ijk[0], ijk[1], ijk[2]] += weight * m return mesh, axis def particle_to_cells_nn(particle, axis): # find the single cell that contains the particle ijk = np.digitize(particle, axis) - 1 # the weight is obviously 1 return [ijk], [1] bbox = np.array([ [1, 0, 0], [-1, 0, 0], [1, 1, 0], [-1, -1, 0], [1, 1, 1], [-1, -1, 1], [1, 1, -1], [-1, -1, -1] ]) def particle_to_cells_cic(particle, axis, width): # create a virtual cell around the particle and check the intersections bounding_cell = particle + width * bbox # find all the cells that intersect with the virtual cell ijks = [] weights = [] for b in bounding_cell: # TODO: this is not the correct weight w = np.linalg.norm(particle - b) ijk = np.digitize(b, axis) - 1 ijks.append(ijk) weights.append(w) # ensure that the weights sum to 1 weights = np.array(weights) weights /= np.sum(weights) return ijks, weights #### Helper functions for mesh plotting def show_mesh_information(mesh: np.ndarray, name: str): logger.info(f"Mesh information for {name}") logger.info(f"Total mapped mass: {np.sum(mesh):.0f}") logger.info(f"Max cell value: {np.max(mesh)}") logger.info(f"Min cell value: {np.min(mesh)}") logger.info(f"Mean cell value: {np.mean(mesh)}") mesh_plot_3d(mesh, name) mesh_plot_2d(mesh, name) def mesh_plot_3d(mesh: np.ndarray, name: str): fig = plt.figure() fig.suptitle(f"{name} - {mesh.shape}") ax = fig.add_subplot(111, projection='3d') sc = ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis') plt.colorbar(sc, ax=ax, label='Density') plt.show() def mesh_plot_2d(mesh: np.ndarray, name: str, only_z: bool = False): fig = plt.figure() fig.suptitle(f"{name} - {mesh.shape}") if only_z: plt.imshow(np.sum(mesh, axis=2), cmap='viridis', origin='lower') else: axs = fig.subplots(1, 3) axs[0].imshow(np.sum(mesh, axis=0), origin='lower') axs[0].set_title("Flattened in x") axs[1].imshow(np.sum(mesh, axis=1), origin='lower') axs[1].set_title("Flattened in y") axs[2].imshow(np.sum(mesh, axis=2), origin='lower') axs[2].set_title("Flattened in z") plt.show()