## Implementation of a mesh based force solver
import logging

import numpy as np
from scipy import fft

try:
    import matplotlib.pyplot as plt
except ImportError:  # matplotlib is only needed for the optional debug plots below
    plt = None

logger = logging.getLogger(__name__)


#### Shared private helpers
def _wave_vectors(shape: tuple, spacing: float) -> np.ndarray:
    """Return the angular wave vectors of an FFT mesh, stacked as ``(ndim, *shape)``.

    Component ``d`` varies along mesh axis ``d`` (``indexing="ij"``), matching how
    particles are indexed into the mesh (``mesh[i, j, k]`` with i along x).
    """
    # fftfreq returns cycles per unit length; derivatives need radians (2 pi f)
    freqs = [2 * np.pi * np.fft.fftfreq(n, d=spacing) for n in shape]
    return np.stack(np.meshgrid(*freqs, indexing="ij"))


def _cell_indices(positions: np.ndarray, axis: np.ndarray) -> np.ndarray:
    """Return the integer mesh indices of the cells containing ``positions``.

    Only coordinates (e.g. ``particles[:, :3]``) should be passed, not the mass column.
    """
    ijk = np.digitize(positions, axis) - 1
    # np.digitize maps a point lying exactly on the last node to len(axis); clip
    # so such points (and any stray outliers) still land on a valid cell.
    return np.clip(ijk, 0, len(axis) - 1)


#### Version 1 - keeping the derivative of phi
def mesh_forces(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
    """
    Computes the gravitational forces between a set of particles using a mesh.
    Assumes that the particles array has the following columns: x, y, z, m.

    The gradient of the potential is computed spectrally by `mesh_poisson` and read
    off at the mesh cell containing each particle.

    Parameters
    ----------
    particles : (N, 4) array with columns x, y, z, m.
    G : gravitational constant.
    n_grid : number of mesh nodes per dimension (at least 2).
    mapping : ``mapping(particle, axis) -> (ijks, weights)`` mass-assignment scheme,
        e.g. `particle_to_cells_nn` or `particle_to_cells_cic`.

    Returns
    -------
    (N, 3) array of forces ``-m * grad(phi)``.

    Raises
    ------
    ValueError
        If ``particles`` is not a 2D array with exactly 4 columns, or if `to_mesh`
        rejects the input.
    """
    if particles.ndim != 2 or particles.shape[1] != 4:
        raise ValueError("Particles array must have 4 columns: x, y, z, m")
    if particles.shape[0] == 0:
        return np.zeros((0, 3))
    mesh, axis = to_mesh(particles, n_grid, mapping)
    spacing = axis[1] - axis[0]
    # the mesh stores mass per cell; Poisson's equation needs a density
    grad_phi = mesh_poisson(mesh / spacing**3, G, spacing=spacing)
    i, j, k = _cell_indices(particles[:, :3], axis).T
    # grad_phi is (3, n, n, n); fancy indexing yields (3, N) -> transpose to (N, 3)
    return -particles[:, 3:4] * grad_phi[:, i, j, k].T


def mesh_poisson(mesh: np.ndarray, G: float, spacing: float = 1.0) -> np.ndarray:
    """
    'Solves' the poisson equation for the mesh using the FFT.
    Returns the gradient of the potential since this is required for the force computation.

    On a periodic mesh, ``laplace(phi) = 4 pi G rho`` becomes
    ``-|k|^2 phi_hat = 4 pi G rho_hat`` in Fourier space, hence
    ``grad(phi)_hat = i k phi_hat = -4 pi G i k rho_hat / |k|^2``.

    Parameters
    ----------
    mesh : density on a periodic mesh.
    G : gravitational constant.
    spacing : physical distance between neighbouring mesh nodes.

    Returns
    -------
    Array of shape ``(mesh.ndim, *mesh.shape)`` whose component ``d`` is the
    derivative of the potential along mesh axis ``d``.
    """
    rho_hat = fft.fftn(mesh)
    k_vec = _wave_vectors(mesh.shape, spacing)
    k_sq = np.sum(k_vec**2, axis=0)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Got k_vec with: %s, %s %s", k_vec.shape, np.max(k_vec), np.min(k_vec))
    # only the k = 0 (mean) mode has |k|^2 = 0; k_vec vanishes there as well, so
    # using 1 as the divisor makes that gradient mode exactly zero
    k_sq[(0,) * mesh.ndim] = 1.0
    grad_phi_hat = -4j * np.pi * G * k_vec * rho_hat / k_sq
    # transform back over the spatial axes only, not over the component axis
    spatial_axes = tuple(range(1, grad_phi_hat.ndim))
    return np.real(fft.ifftn(grad_phi_hat, axes=spatial_axes))


#### Version 2 - only storing the scalar potential
def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callable) -> np.ndarray:
    """
    Computes the gravitational forces between a set of particles using a mesh.
    Assumes that the particles array has the following columns: x, y, z, m.

    Only the scalar potential is solved for (`mesh_poisson_v2`); its gradient is
    then taken with finite differences (`np.gradient`) and read off at the mesh
    cell containing each particle.

    Parameters
    ----------
    particles : (N, 4) array with columns x, y, z, m.
    G : gravitational constant.
    n_grid : number of mesh nodes per dimension (at least 2).
    mapping : ``mapping(particle, axis) -> (ijks, weights)`` mass-assignment scheme.

    Returns
    -------
    (N, 3) array of forces ``-m * grad(phi)``.

    Raises
    ------
    ValueError
        If ``particles`` is not a 2D array with exactly 4 columns, or if `to_mesh`
        rejects the input.
    """
    if particles.ndim != 2 or particles.shape[1] != 4:
        raise ValueError("Particles array must have 4 columns: x, y, z, m")
    if particles.shape[0] == 0:
        return np.zeros((0, 3))
    # Logger.level is only the logger's own setting (0 when unset) and higher values
    # mean *less* logging, so ask the logging system whether DEBUG is actually on.
    debug = logger.isEnabledFor(logging.DEBUG)
    # functools.partial objects have no __name__
    mapping_name = getattr(mapping, "__name__", repr(mapping))
    logger.debug(
        "Computing forces for %d particles using mesh [mapping=%s, n_grid=%s]",
        particles.shape[0], mapping_name, n_grid,
    )
    mesh, axis = to_mesh(particles, n_grid, mapping)
    if debug:
        show_mesh_information(mesh, "Density mesh")
    spacing = axis[1] - axis[0]
    logger.debug("Using mesh spacing: %s", spacing)

    # the mesh stores mass per cell; Poisson's equation needs a density
    phi = mesh_poisson_v2(mesh / spacing**3, G, spacing=spacing)
    if debug:
        logger.debug("Got phi with: %s, %s", phi.shape, np.max(phi))
    phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
    if debug:
        show_mesh_information(phi, "Potential mesh")
        show_mesh_information(phi_grad[0], "Potential gradient")
        logger.debug("Got phi_grad with: %s, %s", phi_grad.shape, np.max(phi_grad))

    i, j, k = _cell_indices(particles[:, :3], axis).T
    # phi_grad is (3, n, n, n); fancy indexing yields (3, N) -> transpose to (N, 3)
    return -particles[:, 3:4] * phi_grad[:, i, j, k].T


def mesh_poisson_v2(mesh: np.ndarray, G: float, spacing: float = 1.0) -> np.ndarray:
    """
    Solves Poisson's equation ``laplace(phi) = 4 pi G rho`` on a periodic mesh via
    the FFT and returns the scalar potential.

    In Fourier space ``phi_hat = -4 pi G rho_hat / |k|^2`` with angular wave
    numbers ``k = 2 pi f``. The k = 0 mode (the mean of phi, an arbitrary
    constant) is set to zero.

    Parameters
    ----------
    mesh : density on a periodic mesh.
    G : gravitational constant.
    spacing : physical distance between neighbouring mesh nodes.

    Returns
    -------
    Potential with the same shape as ``mesh``.
    """
    rho_hat = fft.fftn(mesh)
    k_sq = np.sum(_wave_vectors(mesh.shape, spacing) ** 2, axis=0)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Got k_square with: %s, %s %s", k_sq.shape, np.max(k_sq), np.min(k_sq))
        logger.debug("Count of zeros: %s", np.sum(k_sq == 0))
    zero_mode = (0,) * mesh.ndim
    # avoid 0/0 at k = 0. A tiny epsilon divisor would instead add a constant of
    # order G * total_mass / (eps * mesh.size) to phi, swamping its variations.
    k_sq[zero_mode] = 1.0
    phi_hat = -4 * np.pi * G * rho_hat / k_sq  # the minus sign comes from i^2
    phi_hat[zero_mode] = 0.0
    return np.real(fft.ifftn(phi_hat))


#### Helper functions for star mapping
def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.ndarray, np.ndarray]:
    """
    Maps a list of particles to a mesh of size n_grid x n_grid x n_grid.
    Assumes that the particles array has the following columns: x, y, z, ..., m.
    Uses the mass of the particles and a smoothing function to determine the contribution to each cell.

    The mesh spans [-max_pos, max_pos] in every dimension, where max_pos is the
    largest absolute particle coordinate, so every particle lies inside it.

    Returns
    -------
    mesh : (n_grid, n_grid, n_grid) array holding the mass deposited in each cell.
    axis : (n_grid,) node coordinates shared by all three dimensions.

    Raises
    ------
    ValueError
        If ``n_grid`` < 2 (the mesh spacing would be undefined) or ``particles`` is empty.
    """
    if n_grid < 2:
        raise ValueError("n_grid must be at least 2")
    if len(particles) == 0:
        raise ValueError("Cannot build a mesh from an empty particle array")
    # axis provide an easy way to map the particles to the mesh; use the largest
    # *absolute* coordinate so particles at large negative positions are enclosed too
    max_pos = np.max(np.abs(particles[:, :3]))
    if max_pos == 0:
        # every particle sits at the origin: any extent encloses them, whereas a
        # zero extent would make the mesh spacing zero
        max_pos = 1.0
    axis = np.linspace(-max_pos, max_pos, n_grid)
    mesh = np.zeros((n_grid, n_grid, n_grid))
    check_mass = logger.isEnabledFor(logging.DEBUG)
    for p in particles:
        m = p[-1]
        if check_mass and m <= 0:
            logger.warning("Particle with non-positive mass: %s", p)
        # spread the star onto cells through the shape function, taking into account the mass
        ijks, weights = mapping(p, axis)
        for ijk, weight in zip(ijks, weights):
            mesh[ijk[0], ijk[1], ijk[2]] += weight * m
    return mesh, axis


def particle_to_cells_nn(particle, axis):
    """
    Assigns the whole particle to the single mesh cell that contains it.

    Only the first three entries (x, y, z) of ``particle`` are used; indices are
    clipped to the mesh. Returns ``([ijk], [1])``.
    """
    # the weight is obviously 1
    return [_cell_indices(particle[:3], axis)], [1]


def particle_to_cells_cic(particle, axis, width=None):
    """
    Cloud-in-cell assignment: spreads a particle over the 8 mesh nodes around it.

    Along each dimension the lower node gets ``1 - frac`` and the upper node
    ``frac``, where ``frac`` is the particle's fractional position between them;
    a node's weight is the product over x, y, z. The weights are non-negative and
    sum to 1.

    Parameters
    ----------
    particle : array whose first three entries are x, y, z (further entries,
        e.g. the mass, are ignored).
    axis : evenly spaced node coordinates (at least 2), shared by all dimensions.
    width : kept for backward compatibility and ignored; the cloud is always one
        mesh cell wide. The default lets this function be passed to `to_mesh`.

    Returns
    -------
    ijks : list of 8 integer index triples.
    weights : array of the 8 corresponding weights.
    """
    spacing = axis[1] - axis[0]
    # position in units of cells, measured from the first node
    rel = (np.asarray(particle[:3], dtype=float) - axis[0]) / spacing
    # lower corner of the enclosing cell; clipping keeps the upper neighbour
    # (base + 1) inside the mesh for particles on (or beyond) the last node
    base = np.clip(np.floor(rel).astype(int), 0, len(axis) - 2)
    frac = np.clip(rel - base, 0.0, 1.0)

    ijks = []
    weights = []
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                offset = np.array([dx, dy, dz])
                weights.append(np.prod(np.where(offset == 1, frac, 1.0 - frac)))
                ijks.append(base + offset)
    return ijks, np.array(weights)


#### Helper functions for mesh plotting
def _pyplot():
    """Return ``matplotlib.pyplot``, raising a clear error if it is not installed."""
    if plt is None:
        raise ImportError("matplotlib is required for mesh plotting")
    return plt


def show_mesh_information(mesh: np.ndarray, name: str):
    """Log summary statistics of ``mesh`` at INFO level, then show 3D and 2D plots of it."""
    logger.info("Mesh information for %s", name)
    logger.info("Total mapped mass: %.0f", np.sum(mesh))
    logger.info("Max cell value: %s", np.max(mesh))
    logger.info("Min cell value: %s", np.min(mesh))
    logger.info("Mean cell value: %s", np.mean(mesh))
    plot_3d(mesh, name)
    plot_2d(mesh, name)


def plot_3d(mesh: np.ndarray, name: str):
    """Scatter-plot the non-zero cells of ``mesh`` in 3D, coloured by value."""
    pyplot = _pyplot()
    fig = pyplot.figure()
    fig.suptitle(f"{name} - {mesh.shape}")
    ax = fig.add_subplot(111, projection='3d')
    occupied = np.nonzero(mesh)
    ax.scatter(*occupied, c=mesh[occupied], cmap='viridis')
    pyplot.show()


def plot_2d(mesh: np.ndarray, name: str):
    """Show ``mesh`` summed along each of its three axes as side-by-side images."""
    pyplot = _pyplot()
    fig = pyplot.figure()
    fig.suptitle(f"{name} - {mesh.shape}")
    axs = fig.subplots(1, 3)
    for sum_axis, (ax, label) in enumerate(zip(axs, "xyz")):
        ax.imshow(np.sum(mesh, axis=sum_axis))
        ax.set_title(f"Flattened in {label}")
    pyplot.show()