cleanup for the presentation
This commit is contained in:
@@ -29,7 +29,8 @@ def n_body_forces(particles: np.ndarray, G: float, softening: float = 0):
|
||||
|
||||
# add softening to the denominator
|
||||
r_adjusted = r**2 + softening**2
|
||||
# usually with a square root: r' = sqrt(r^2 + softening^2) and then cubed, but we combine that below
|
||||
# usually with a square root: r' = sqrt(r^2 + softening^2)
|
||||
# and then cubed, but we combine that below
|
||||
|
||||
# the numerator is tricky:
|
||||
# m is a list of scalars and r_vec is a list of vectors (2D array)
|
||||
|
@@ -81,9 +81,12 @@ def to_particles_3d(y: np.ndarray) -> np.ndarray:
|
||||
return y
|
||||
|
||||
|
||||
def runge_kutta_4(y0 : np.ndarray, t : float, f, dt : float):
|
||||
k1 = f(y0, t)
|
||||
k2 = f(y0 + k1/2 * dt, t + dt/2)
|
||||
k3 = f(y0 + k2/2 * dt, t + dt/2)
|
||||
k4 = f(y0 + k3 * dt, t + dt)
|
||||
return y0 + (k1 + 2*k2 + 2*k3 + k4)/6 * dt
|
||||
def runge_kutta_4(y: np.ndarray, t: float, f: callable, dt: float):
|
||||
"""
|
||||
Runge-Kutta 4th order integrator.
|
||||
"""
|
||||
k1 = f(y, t)
|
||||
k2 = f(y + k1/2 * dt, t + dt/2)
|
||||
k3 = f(y + k2/2 * dt, t + dt/2)
|
||||
k4 = f(y + k3 * dt, t + dt)
|
||||
return y + (k1 + 2*k2 + 2*k3 + k4)/6 * dt
|
||||
|
@@ -1,7 +1,10 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from astropy import units as u
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def density_distribution(r_bins: np.ndarray, particles: np.ndarray, ret_error: bool = False):
|
||||
"""
|
||||
@@ -146,7 +149,8 @@ def total_energy(particles: np.ndarray):
|
||||
return ke + pe
|
||||
|
||||
|
||||
def particles_plot_3d(particles: np.ndarray, title: str = "Particle distribution (3D)"):
|
||||
|
||||
def particles_plot_3d(positions: np.ndarray, masses: np.ndarray, title: str = "Particle distribution (3D)"):
|
||||
"""
|
||||
Plots a 3D scatter plot of a set of particles.
|
||||
Assumes that the particles array has the shape:
|
||||
@@ -154,48 +158,54 @@ def particles_plot_3d(particles: np.ndarray, title: str = "Particle distribution
|
||||
- or 7 columns: x, y, z, vx, vy, vz, m
|
||||
Colormap is the mass of the particles.
|
||||
"""
|
||||
if particles.shape[1] == 4:
|
||||
x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 3]
|
||||
c = m
|
||||
elif particles.shape[1] == 7:
|
||||
x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 6]
|
||||
c = m
|
||||
else:
|
||||
raise ValueError("Particles array must have 4 or 7 columns")
|
||||
|
||||
x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
|
||||
|
||||
fig = plt.figure()
|
||||
plt.title(title)
|
||||
fig.suptitle(title)
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
ax.scatter(particles[:,0], particles[:,1], particles[:,2], cmap='viridis', c=particles[:,3])
|
||||
sc = ax.scatter(x, y, z, cmap='coolwarm', c=masses)
|
||||
cbar = plt.colorbar(sc, ax=ax, pad=0.1)
|
||||
|
||||
try:
|
||||
cbar.set_label(f'Mass [{masses.unit:latex}]')
|
||||
ax.set_xlabel(f'x [{x.unit:latex}]')
|
||||
ax.set_ylabel(f'y [{x.unit:latex}]')
|
||||
ax.set_zlabel(f'z [{x.unit:latex}]')
|
||||
except AttributeError:
|
||||
cbar.set_label('Mass')
|
||||
ax.set_xlabel('x')
|
||||
ax.set_ylabel('y')
|
||||
ax.set_zlabel('z')
|
||||
|
||||
plt.show()
|
||||
logger.debug("3D scatter plot with mass colormap")
|
||||
|
||||
|
||||
def particles_plot_2d(particles: np.ndarray, title: str = "Flattened distribution (along z)"):
|
||||
|
||||
def particles_plot_2d(particles: np.ndarray, title: str = "Flattened distribution (along z)", ax = None):
|
||||
"""
|
||||
Plots a 2 colormap of a set of particles, flattened in the z direction.
|
||||
Assumes that the particles array has the shape:
|
||||
- either 4 columns: x, y, z, m
|
||||
- either 3 columns: x, y, z
|
||||
- or 4 columns: x, y, z, m
|
||||
- or 7 columns: x, y, z, vx, vy, vz, m
|
||||
"""
|
||||
if particles.shape[1] == 4:
|
||||
if particles.shape[1] == 3:
|
||||
x, y, z = particles[:, 0], particles[:, 1], particles[:, 2]
|
||||
elif particles.shape[1] == 4:
|
||||
x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 3]
|
||||
c = m
|
||||
elif particles.shape[1] == 7:
|
||||
x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 6]
|
||||
c = m
|
||||
else:
|
||||
raise ValueError("Particles array must have 4 or 7 columns")
|
||||
|
||||
# plt.figure()
|
||||
# plt.title(title)
|
||||
# plt.scatter(x, y, c=range(particles.shape[0]))
|
||||
# plt.colorbar()
|
||||
# plt.show()
|
||||
raise ValueError("Particles array must have 3, 4 or 7 columns")
|
||||
|
||||
# or as a discrete heatmap
|
||||
plt.figure()
|
||||
plt.title(title)
|
||||
plt.hist2d(x, y, bins=100, cmap='viridis')
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
if ax is None:
|
||||
plt.figure()
|
||||
plt.title(title)
|
||||
plt.hist2d(x, y, bins=100, cmap='coolwarm')
|
||||
cbar = plt.colorbar()
|
||||
cbar.set_label(f'Particle count')
|
||||
|
||||
plt.show()
|
||||
else:
|
||||
ax.hist2d(x, y, bins=100, cmap='coolwarm')
|
||||
|
Reference in New Issue
Block a user