cleanup for the presentation
This commit is contained in:
		| @@ -29,7 +29,8 @@ def n_body_forces(particles: np.ndarray, G: float, softening: float = 0): | ||||
|  | ||||
|         # add softening to the denominator | ||||
|         r_adjusted = r**2 + softening**2 | ||||
|         # usually with a square root: r' = sqrt(r^2 + softening^2) and then cubed, but we combine that below | ||||
|         # usually with a square root: r' = sqrt(r^2 + softening^2) | ||||
|         # and then cubed, but we combine that below | ||||
|  | ||||
|         # the numerator is tricky: | ||||
|         # m is a list of scalars and r_vec is a list of vectors (2D array) | ||||
|   | ||||
| @@ -81,9 +81,12 @@ def to_particles_3d(y: np.ndarray) -> np.ndarray: | ||||
|     return y | ||||
|  | ||||
|  | ||||
| def runge_kutta_4(y0 : np.ndarray, t : float, f, dt : float): | ||||
|     k1 = f(y0, t) | ||||
|     k2 = f(y0 + k1/2 * dt, t + dt/2) | ||||
|     k3 = f(y0 + k2/2 * dt, t + dt/2) | ||||
|     k4 = f(y0 + k3 * dt, t + dt) | ||||
|     return y0 + (k1 + 2*k2 + 2*k3 + k4)/6 * dt | ||||
| def runge_kutta_4(y: np.ndarray, t: float, f: callable, dt: float): | ||||
|     """ | ||||
|     Runge-Kutta 4th order integrator. | ||||
|     """ | ||||
|     k1 = f(y, t) | ||||
|     k2 = f(y + k1/2 * dt, t + dt/2) | ||||
|     k3 = f(y + k2/2 * dt, t + dt/2) | ||||
|     k4 = f(y + k3 * dt, t + dt) | ||||
|     return y + (k1 + 2*k2 + 2*k3 + k4)/6 * dt | ||||
|   | ||||
| @@ -1,7 +1,10 @@ | ||||
| import numpy as np | ||||
| import matplotlib.pyplot as plt | ||||
| from astropy import units as u | ||||
|  | ||||
| import logging | ||||
| logger = logging.getLogger(__name__) | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
|  | ||||
| def density_distribution(r_bins: np.ndarray, particles: np.ndarray, ret_error: bool = False): | ||||
|     """ | ||||
| @@ -146,7 +149,8 @@ def total_energy(particles: np.ndarray): | ||||
|     return ke + pe | ||||
|  | ||||
|  | ||||
| def particles_plot_3d(particles: np.ndarray, title: str = "Particle distribution (3D)"): | ||||
|  | ||||
| def particles_plot_3d(positions: np.ndarray, masses: np.ndarray, title: str = "Particle distribution (3D)"): | ||||
|     """ | ||||
|     Plots a 3D scatter plot of a set of particles. | ||||
|     Assumes that the particles array has the shape: | ||||
| @@ -154,48 +158,54 @@ def particles_plot_3d(particles: np.ndarray, title: str = "Particle distribution | ||||
|     - or 7 columns: x, y, z, vx, vy, vz, m | ||||
|     Colormap is the mass of the particles. | ||||
|     """ | ||||
|     if particles.shape[1] == 4: | ||||
|         x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 3] | ||||
|         c = m | ||||
|     elif particles.shape[1] == 7: | ||||
|         x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 6] | ||||
|         c = m | ||||
|     else: | ||||
|         raise ValueError("Particles array must have 4 or 7 columns") | ||||
|      | ||||
|     x, y, z = positions[:, 0], positions[:, 1], positions[:, 2] | ||||
|  | ||||
|     fig = plt.figure() | ||||
|     plt.title(title) | ||||
|     fig.suptitle(title) | ||||
|     ax = fig.add_subplot(111, projection='3d') | ||||
|     ax.scatter(particles[:,0], particles[:,1], particles[:,2], cmap='viridis', c=particles[:,3]) | ||||
|     sc = ax.scatter(x, y, z, cmap='coolwarm', c=masses) | ||||
|     cbar = plt.colorbar(sc, ax=ax, pad=0.1) | ||||
|  | ||||
|     try: | ||||
|         cbar.set_label(f'Mass [{masses.unit:latex}]') | ||||
|         ax.set_xlabel(f'x [{x.unit:latex}]') | ||||
|         ax.set_ylabel(f'y [{x.unit:latex}]') | ||||
|         ax.set_zlabel(f'z [{x.unit:latex}]') | ||||
|     except AttributeError: | ||||
|         cbar.set_label('Mass') | ||||
|         ax.set_xlabel('x') | ||||
|         ax.set_ylabel('y') | ||||
|         ax.set_zlabel('z') | ||||
|  | ||||
|     plt.show() | ||||
|     logger.debug("3D scatter plot with mass colormap") | ||||
|  | ||||
|  | ||||
| def particles_plot_2d(particles: np.ndarray, title: str = "Flattened distribution (along z)"): | ||||
|  | ||||
| def particles_plot_2d(particles: np.ndarray, title: str = "Flattened distribution (along z)", ax = None): | ||||
|     """ | ||||
|     Plots a 2 colormap of a set of particles, flattened in the z direction. | ||||
|     Assumes that the particles array has the shape: | ||||
|     - either 4 columns: x, y, z, m | ||||
|     - either 3 columns: x, y, z | ||||
|     - or 4 columns: x, y, z, m | ||||
|     - or 7 columns: x, y, z, vx, vy, vz, m | ||||
|     """ | ||||
|     if particles.shape[1] == 4: | ||||
|     if particles.shape[1] == 3: | ||||
|         x, y, z = particles[:, 0], particles[:, 1], particles[:, 2] | ||||
|     elif particles.shape[1] == 4: | ||||
|         x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 3] | ||||
|         c = m | ||||
|     elif particles.shape[1] == 7: | ||||
|         x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 6] | ||||
|         c = m | ||||
|     else: | ||||
|         raise ValueError("Particles array must have 4 or 7 columns") | ||||
|      | ||||
|     # plt.figure() | ||||
|     # plt.title(title) | ||||
|     # plt.scatter(x, y, c=range(particles.shape[0])) | ||||
|     # plt.colorbar() | ||||
|     # plt.show() | ||||
|         raise ValueError("Particles array must have 3, 4 or 7 columns") | ||||
|  | ||||
|     # or as a discrete heatmap | ||||
|     plt.figure() | ||||
|     plt.title(title) | ||||
|     plt.hist2d(x, y, bins=100, cmap='viridis') | ||||
|     plt.colorbar() | ||||
|     plt.show() | ||||
|     if ax is None: | ||||
|         plt.figure() | ||||
|         plt.title(title) | ||||
|         plt.hist2d(x, y, bins=100, cmap='coolwarm') | ||||
|         cbar = plt.colorbar() | ||||
|         cbar.set_label(f'Particle count') | ||||
|  | ||||
|         plt.show() | ||||
|     else: | ||||
|         ax.hist2d(x, y, bins=100, cmap='coolwarm') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user