cleanup for the presentation

This commit is contained in:
2025-02-01 18:33:07 +01:00
parent 7db6480e51
commit 37a2687ffe
8 changed files with 904 additions and 638 deletions

View File

@@ -29,7 +29,8 @@ def n_body_forces(particles: np.ndarray, G: float, softening: float = 0):
# add softening to the denominator
r_adjusted = r**2 + softening**2
# usually with a square root: r' = sqrt(r^2 + softening^2) and then cubed, but we combine that below
# usually with a square root: r' = sqrt(r^2 + softening^2)
# and then cubed, but we combine that below
# the numerator is tricky:
# m is a list of scalars and r_vec is a list of vectors (2D array)

View File

@@ -81,9 +81,12 @@ def to_particles_3d(y: np.ndarray) -> np.ndarray:
return y
def runge_kutta_4(y0 : np.ndarray, t : float, f, dt : float):
k1 = f(y0, t)
k2 = f(y0 + k1/2 * dt, t + dt/2)
k3 = f(y0 + k2/2 * dt, t + dt/2)
k4 = f(y0 + k3 * dt, t + dt)
return y0 + (k1 + 2*k2 + 2*k3 + k4)/6 * dt
def runge_kutta_4(y: np.ndarray, t: float, f: callable, dt: float):
"""
Runge-Kutta 4th order integrator.
"""
k1 = f(y, t)
k2 = f(y + k1/2 * dt, t + dt/2)
k3 = f(y + k2/2 * dt, t + dt/2)
k4 = f(y + k3 * dt, t + dt)
return y + (k1 + 2*k2 + 2*k3 + k4)/6 * dt

View File

@@ -1,7 +1,10 @@
import numpy as np
import matplotlib.pyplot as plt
from astropy import units as u
import logging
logger = logging.getLogger(__name__)
import matplotlib.pyplot as plt
def density_distribution(r_bins: np.ndarray, particles: np.ndarray, ret_error: bool = False):
"""
@@ -146,7 +149,8 @@ def total_energy(particles: np.ndarray):
return ke + pe
def particles_plot_3d(particles: np.ndarray, title: str = "Particle distribution (3D)"):
def particles_plot_3d(positions: np.ndarray, masses: np.ndarray, title: str = "Particle distribution (3D)"):
"""
Plots a 3D scatter plot of a set of particles.
Assumes that the particles array has the shape:
@@ -154,48 +158,54 @@ def particles_plot_3d(particles: np.ndarray, title: str = "Particle distribution
- or 7 columns: x, y, z, vx, vy, vz, m
Colormap is the mass of the particles.
"""
if particles.shape[1] == 4:
x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 3]
c = m
elif particles.shape[1] == 7:
x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 6]
c = m
else:
raise ValueError("Particles array must have 4 or 7 columns")
x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
fig = plt.figure()
plt.title(title)
fig.suptitle(title)
ax = fig.add_subplot(111, projection='3d')
ax.scatter(particles[:,0], particles[:,1], particles[:,2], cmap='viridis', c=particles[:,3])
sc = ax.scatter(x, y, z, cmap='coolwarm', c=masses)
cbar = plt.colorbar(sc, ax=ax, pad=0.1)
try:
cbar.set_label(f'Mass [{masses.unit:latex}]')
ax.set_xlabel(f'x [{x.unit:latex}]')
ax.set_ylabel(f'y [{x.unit:latex}]')
ax.set_zlabel(f'z [{x.unit:latex}]')
except AttributeError:
cbar.set_label('Mass')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
logger.debug("3D scatter plot with mass colormap")
def particles_plot_2d(particles: np.ndarray, title: str = "Flattened distribution (along z)"):
def particles_plot_2d(particles: np.ndarray, title: str = "Flattened distribution (along z)", ax = None):
"""
Plots a 2 colormap of a set of particles, flattened in the z direction.
Assumes that the particles array has the shape:
- either 4 columns: x, y, z, m
- either 3 columns: x, y, z
- or 4 columns: x, y, z, m
- or 7 columns: x, y, z, vx, vy, vz, m
"""
if particles.shape[1] == 4:
if particles.shape[1] == 3:
x, y, z = particles[:, 0], particles[:, 1], particles[:, 2]
elif particles.shape[1] == 4:
x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 3]
c = m
elif particles.shape[1] == 7:
x, y, z, m = particles[:, 0], particles[:, 1], particles[:, 2], particles[:, 6]
c = m
else:
raise ValueError("Particles array must have 4 or 7 columns")
# plt.figure()
# plt.title(title)
# plt.scatter(x, y, c=range(particles.shape[0]))
# plt.colorbar()
# plt.show()
raise ValueError("Particles array must have 3, 4 or 7 columns")
# or as a discrete heatmap
plt.figure()
plt.title(title)
plt.hist2d(x, y, bins=100, cmap='viridis')
plt.colorbar()
plt.show()
if ax is None:
plt.figure()
plt.title(title)
plt.hist2d(x, y, bins=100, cmap='coolwarm')
cbar = plt.colorbar()
cbar.set_label(f'Particle count')
plt.show()
else:
ax.hist2d(x, y, bins=100, cmap='coolwarm')