cleanup of t1, tried integration in t2, units
This commit is contained in:
parent
77a4959fe2
commit
9e856bc854
3
.gitignore
vendored
3
.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
*.pyc
|
*.pyc
|
||||||
|
.cache/
|
2
Pipfile
2
Pipfile
@ -9,7 +9,7 @@ ipython = "*"
|
|||||||
jupyter = "*"
|
jupyter = "*"
|
||||||
matplotlib = "*"
|
matplotlib = "*"
|
||||||
scipy = "*"
|
scipy = "*"
|
||||||
spacepy = "*"
|
astropy = "*"
|
||||||
|
|
||||||
[dev-packages]
|
[dev-packages]
|
||||||
|
|
||||||
|
1644
Pipfile.lock
generated
1644
Pipfile.lock
generated
File diff suppressed because it is too large
Load Diff
@ -8,7 +8,21 @@
|
|||||||
- [x] compare with the analytical expectation from Newtons 2nd law
|
- [x] compare with the analytical expectation from Newtons 2nd law
|
||||||
- [ ] compute the relaxation time
|
- [ ] compute the relaxation time
|
||||||
|
|
||||||
### Task 2
|
### Task 2 (particle mesh)
|
||||||
|
- [ ] Choose reasonable units
|
||||||
|
- [ ] Implement force computation on mesh
|
||||||
|
- [ ] Find optimal mesh size
|
||||||
|
- [ ] Compare with direct nbody simulation
|
||||||
|
- [ ] Time integration for direct method AND mesh method
|
||||||
|
|
||||||
|
|
||||||
|
### Task 2 (tree code)
|
||||||
|
- [ ] Implement force computation with multipole expansion
|
||||||
|
- [ ] Find optimal grouping criterion
|
||||||
|
- [ ] Compare with direct nbody simulation
|
||||||
|
- [ ] Time integration for direct method AND tree method
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -16,4 +30,8 @@
|
|||||||
|
|
||||||
|
|
||||||
### Questions
|
### Questions
|
||||||
- Procedure for each time step of a mesh simulation? Potential on mesh -> forces on particles -> update particle positions -> new mesh potential? or skip the creation of particles in each time step?
|
- Procedure for each time step of a mesh simulation? Potential on mesh -> forces on particles -> update particle positions -> new mesh potential? or skip the creation of particles in each time step?
|
||||||
|
- How to represent the time evolution of the system?
|
||||||
|
- plot total energy vs time
|
||||||
|
- plot particle positions?
|
||||||
|
- What is the parameter a of the Hernquist model?
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1,9 +1,9 @@
|
|||||||
## Import all functions in all the files in the current directory
|
## Import all functions in all the files in the current directory
|
||||||
# Basic helpers for interacting with the data
|
# Basic helpers for interacting with the data
|
||||||
from .load import *
|
from .load import *
|
||||||
from .mesh import *
|
|
||||||
from .model import *
|
from .model import *
|
||||||
from .particles import *
|
from .particles import *
|
||||||
|
from .units import *
|
||||||
|
|
||||||
# Helpers for computing the forces
|
# Helpers for computing the forces
|
||||||
from .forces_basic import *
|
from .forces_basic import *
|
||||||
|
@ -16,7 +16,7 @@ def n_body_forces(particles: np.ndarray, G: float, softening: float = 0):
|
|||||||
|
|
||||||
n = particles.shape[0]
|
n = particles.shape[0]
|
||||||
forces = np.zeros((n, 3))
|
forces = np.zeros((n, 3))
|
||||||
logger.debug(f"Computing forces for {n} particles using n^2 algorithm (using {softening=})")
|
logger.debug(f"Computing forces for {n} particles using n^2 algorithm (using {softening=:.2g})")
|
||||||
|
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
# the current particle is at x_current
|
# the current particle is at x_current
|
||||||
|
@ -56,19 +56,22 @@ def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callab
|
|||||||
if particles.shape[1] != 4:
|
if particles.shape[1] != 4:
|
||||||
raise ValueError("Particles array must have 4 columns: x, y, z, m")
|
raise ValueError("Particles array must have 4 columns: x, y, z, m")
|
||||||
|
|
||||||
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh (using mapping={mapping.__name__})")
|
logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]")
|
||||||
|
|
||||||
mesh, axis = to_mesh(particles, n_grid, mapping)
|
mesh, axis = to_mesh(particles, n_grid, mapping)
|
||||||
show_mesh_information(mesh, "Mesh")
|
if logger.level >= logging.DEBUG:
|
||||||
|
show_mesh_information(mesh, "Density mesh")
|
||||||
|
|
||||||
spacing = axis[1] - axis[0]
|
spacing = axis[1] - axis[0]
|
||||||
logger.debug(f"Using mesh spacing: {spacing}")
|
logger.debug(f"Using mesh spacing: {spacing}")
|
||||||
phi = mesh_poisson_v2(mesh, G)
|
phi = mesh_poisson_v2(mesh, G)
|
||||||
logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
|
logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}")
|
||||||
phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
|
phi_grad = np.stack(np.gradient(phi, spacing), axis=0)
|
||||||
show_mesh_information(phi, "Mesh potential")
|
if logger.level >= logging.DEBUG:
|
||||||
show_mesh_information(phi_grad[0], "Mesh potential grad x")
|
show_mesh_information(phi, "Potential mesh")
|
||||||
|
show_mesh_information(phi_grad[0], "Potential gradient")
|
||||||
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
|
logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}")
|
||||||
|
|
||||||
forces = np.zeros_like(particles[:, :3])
|
forces = np.zeros_like(particles[:, :3])
|
||||||
for i, p in enumerate(particles):
|
for i, p in enumerate(particles):
|
||||||
ijk = np.digitize(p, axis) - 1
|
ijk = np.digitize(p, axis) - 1
|
||||||
@ -93,7 +96,7 @@ def mesh_poisson_v2(mesh: np.ndarray, G: float) -> np.ndarray:
|
|||||||
return phi
|
return phi
|
||||||
|
|
||||||
|
|
||||||
## Helper functions for star mapping
|
#### Helper functions for star mapping
|
||||||
def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.ndarray, np.ndarray]:
|
def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid.
|
Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid.
|
||||||
@ -108,7 +111,7 @@ def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.n
|
|||||||
|
|
||||||
for p in particles:
|
for p in particles:
|
||||||
m = p[-1]
|
m = p[-1]
|
||||||
if logger.level <= logging.DEBUG and m <= 0:
|
if logger.level >= logging.DEBUG and m <= 0:
|
||||||
logger.warning(f"Particle with negative mass: {p}")
|
logger.warning(f"Particle with negative mass: {p}")
|
||||||
# spread the star onto cells through the shape function, taking into account the mass
|
# spread the star onto cells through the shape function, taking into account the mass
|
||||||
ijks, weights = mapping(p, axis)
|
ijks, weights = mapping(p, axis)
|
||||||
@ -153,20 +156,20 @@ def particle_to_cells_cic(particle, axis, width):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Helper functions for mesh plotting
|
#### Helper functions for mesh plotting
|
||||||
def show_mesh_information(mesh: np.ndarray, name: str):
|
def show_mesh_information(mesh: np.ndarray, name: str):
|
||||||
print(f"Mesh information for {name}")
|
logger.info(f"Mesh information for {name}")
|
||||||
print(f"Total mapped mass: {np.sum(mesh):.0f}")
|
logger.info(f"Total mapped mass: {np.sum(mesh):.0f}")
|
||||||
print(f"Max cell value: {np.max(mesh)}")
|
logger.info(f"Max cell value: {np.max(mesh)}")
|
||||||
print(f"Min cell value: {np.min(mesh)}")
|
logger.info(f"Min cell value: {np.min(mesh)}")
|
||||||
print(f"Mean cell value: {np.mean(mesh)}")
|
logger.info(f"Mean cell value: {np.mean(mesh)}")
|
||||||
plot_3d(mesh, name)
|
plot_3d(mesh, name)
|
||||||
plot_2d(mesh, name)
|
plot_2d(mesh, name)
|
||||||
|
|
||||||
|
|
||||||
def plot_3d(mesh: np.ndarray, name: str):
|
def plot_3d(mesh: np.ndarray, name: str):
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
fig.suptitle(name)
|
fig.suptitle(f"{name} - {mesh.shape}")
|
||||||
ax = fig.add_subplot(111, projection='3d')
|
ax = fig.add_subplot(111, projection='3d')
|
||||||
ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis')
|
ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis')
|
||||||
plt.show()
|
plt.show()
|
||||||
@ -174,7 +177,7 @@ def plot_3d(mesh: np.ndarray, name: str):
|
|||||||
|
|
||||||
def plot_2d(mesh: np.ndarray, name: str):
|
def plot_2d(mesh: np.ndarray, name: str):
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
fig.suptitle(name)
|
fig.suptitle(f"{name} - {mesh.shape}")
|
||||||
axs = fig.subplots(1, 3)
|
axs = fig.subplots(1, 3)
|
||||||
axs[0].imshow(np.sum(mesh, axis=0))
|
axs[0].imshow(np.sum(mesh, axis=0))
|
||||||
axs[0].set_title("Flattened in x")
|
axs[0].set_title("Flattened in x")
|
||||||
@ -182,4 +185,4 @@ def plot_2d(mesh: np.ndarray, name: str):
|
|||||||
axs[1].set_title("Flattened in y")
|
axs[1].set_title("Flattened in y")
|
||||||
axs[2].imshow(np.sum(mesh, axis=2))
|
axs[2].imshow(np.sum(mesh, axis=2))
|
||||||
axs[2].set_title("Flattened in z")
|
axs[2].set_title("Flattened in z")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
@ -16,54 +16,53 @@ def ode_setup(particles: np.ndarray, force_function: callable) -> tuple[np.ndarr
|
|||||||
if particles.shape[1] != 7:
|
if particles.shape[1] != 7:
|
||||||
raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m")
|
raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m")
|
||||||
|
|
||||||
n = particles.shape[0]
|
# for scipy integrators we need to flatten array which contains 7 columns for now
|
||||||
# for scipy integrators we need to flatten the n 3D positions and n 3D velocities
|
# we don't really care how we reshape as long as we unflatten consistently afterwards
|
||||||
y0 = np.zeros(6*n)
|
particles = particles.reshape(-1, copy=False, order='A')
|
||||||
y0[:3*n] = particles[:, :3].flatten()
|
# this is consistent with the unflattening in to_particles()!
|
||||||
y0[3*n:] = particles[:, 3:6].flatten()
|
logger.debug(f"Reshaped 7 columns into {particles.shape=}")
|
||||||
|
|
||||||
# the masses don't change we can define them once
|
|
||||||
masses = particles[:, 6]
|
|
||||||
logger.debug(f"Reshaped {particles.shape} to y0 with {y0.shape} and masses with {masses.shape}")
|
|
||||||
|
|
||||||
|
|
||||||
def f(y, t):
|
def f(y, t):
|
||||||
"""
|
"""
|
||||||
Computes the right hand side of the ODE system.
|
Computes the right hand side of the ODE system.
|
||||||
The ODE system is linearized around the current positions and velocities.
|
The ODE system is linearized around the current positions and velocities.
|
||||||
"""
|
"""
|
||||||
n = y.size // 6
|
y = to_particles(y)
|
||||||
logger.debug(f"y with shape {y.shape}")
|
# now y has shape (n, 7), with columns x, y, z, vx, vy, vz, m
|
||||||
# unsqueeze and unstack to extract the positions and velocities
|
|
||||||
y = y.reshape((2*n, 3))
|
|
||||||
x = y[:n, ...]
|
|
||||||
v = y[n:, ...]
|
|
||||||
logger.debug(f"Unstacked y into x with shape {x.shape} and v with shape {v.shape}")
|
|
||||||
|
|
||||||
# compute the forces
|
|
||||||
x_with_m = np.zeros((n, 4))
|
forces = force_function(y[:, [0, 1, 2, -1]])
|
||||||
x_with_m[:, :3] = x
|
|
||||||
x_with_m[:, 3] = masses
|
|
||||||
forces = force_function(x_with_m)
|
|
||||||
|
|
||||||
# compute the accelerations
|
# compute the accelerations
|
||||||
|
masses = y[:, -1]
|
||||||
a = forces / masses[:, None]
|
a = forces / masses[:, None]
|
||||||
a.flatten()
|
|
||||||
# the [:, None] is to force broadcasting in order to divide each row of forces by the corresponding mass
|
# the [:, None] is to force broadcasting in order to divide each row of forces by the corresponding mass
|
||||||
|
# a.flatten()
|
||||||
|
|
||||||
# reshape into a 1D array
|
# replace some values in y:
|
||||||
return np.vstack((v, a)).flatten()
|
# the position columns become the velocities
|
||||||
|
# the velocity columns become the accelerations
|
||||||
return y0, f
|
y[:, 0:3] = y[:, 3:6]
|
||||||
|
y[:, 3:6] = a
|
||||||
|
# the masses remain unchanged
|
||||||
|
|
||||||
|
# flatten the array again
|
||||||
|
y = y.reshape(-1, copy=False, order='A')
|
||||||
|
return y
|
||||||
|
|
||||||
|
return particles, f
|
||||||
|
|
||||||
|
|
||||||
def to_particles(y: np.ndarray) -> np.ndarray:
|
def to_particles(y: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Converts the 1D array y into a 2D array with the shape (n, 6) where n is the number of particles.
|
Converts the 1D array y into a 2D array IN PLACE
|
||||||
The columns are x, y, z, vx, vy, vz
|
The new shape is (n, 7) where n is the number of particles.
|
||||||
|
The columns are x, y, z, vx, vy, vz, m
|
||||||
"""
|
"""
|
||||||
n = y.size // 6
|
if y.size % 7 != 0:
|
||||||
y = y.reshape((2*n, 3))
|
raise ValueError("The array y should be inflatable to 7 columns")
|
||||||
x = y[:n, ...]
|
|
||||||
v = y[n:, ...]
|
n = y.size // 7
|
||||||
return np.hstack((x, v))
|
y = y.reshape((n, 7), copy=False, order='F')
|
||||||
|
logger.debug(f"Unflattened array into {y.shape=}")
|
||||||
|
return y
|
@ -3,7 +3,7 @@ import logging
|
|||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
## set logging level
|
## set logging level
|
||||||
# level=logging.INFO,
|
# level=logging.INFO,
|
||||||
level=logging.DEBUG,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(message)s',
|
format='%(asctime)s - %(name)s - %(message)s',
|
||||||
datefmt='%H:%M:%S'
|
datefmt='%H:%M:%S'
|
||||||
)
|
)
|
||||||
@ -13,3 +13,4 @@ logging.basicConfig(
|
|||||||
logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING)
|
logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING)
|
||||||
logging.getLogger('matplotlib.ticker').setLevel(logging.WARNING)
|
logging.getLogger('matplotlib.ticker').setLevel(logging.WARNING)
|
||||||
logging.getLogger('matplotlib.pyplot').setLevel(logging.WARNING)
|
logging.getLogger('matplotlib.pyplot').setLevel(logging.WARNING)
|
||||||
|
logging.getLogger('matplotlib.colorbar').setLevel(logging.WARNING)
|
@ -1,12 +1,14 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
M = 5
|
|
||||||
a = 5
|
|
||||||
|
|
||||||
def model_density_distribution(r_bins: np.ndarray):
|
def model_density_distribution(r_bins: np.ndarray, M: float = 5, a: float = 5) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Generate a density distribution for a spherical galaxy model, as per the Hernquist model.
|
Generate a density distribution for a spherical galaxy model, as per the Hernquist model.
|
||||||
|
Parameters:
|
||||||
|
- r_bins: The radial bins to calculate the density distribution for.
|
||||||
|
- M: The total mass of the galaxy.
|
||||||
|
- a: The scale radius of the galaxy.
|
||||||
See https://doi.org/10.1086%2F168845 for more information.
|
See https://doi.org/10.1086%2F168845 for more information.
|
||||||
"""
|
"""
|
||||||
rho = M / (2 * np.pi) * a / (r_bins * (r_bins + a)**3)
|
rho = M / (2 * np.pi) * a / (r_bins * (r_bins + a)**3)
|
||||||
return rho
|
return rho
|
@ -1,5 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
|
from . import forces_basic
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -7,20 +8,34 @@ def density_distribution(r_bins: np.ndarray, particles: np.ndarray, ret_error: b
|
|||||||
"""
|
"""
|
||||||
Computes the radial density distribution of a set of particles.
|
Computes the radial density distribution of a set of particles.
|
||||||
Assumes that the particles array has the following columns: x, y, z, m.
|
Assumes that the particles array has the following columns: x, y, z, m.
|
||||||
|
If ret_error is True, it will return the absolute error of the density.
|
||||||
"""
|
"""
|
||||||
if particles.shape[1] != 4:
|
if particles.shape[1] != 4:
|
||||||
raise ValueError("Particles array must have 4 columns: x, y, z, m")
|
raise ValueError("Particles array must have 4 columns: x, y, z, m")
|
||||||
|
|
||||||
m = particles[:, 3]
|
m = particles[:, 3]
|
||||||
r = np.linalg.norm(particles[:, :3], axis=1)
|
r = np.linalg.norm(particles[:, :3], axis=1)
|
||||||
density = [np.sum(m[(r >= r_bins[i]) & (r < r_bins[i + 1])]) for i in range(len(r_bins) - 1)]
|
|
||||||
|
m_shells = np.zeros_like(r_bins)
|
||||||
|
v_shells = np.zeros_like(r_bins)
|
||||||
|
error_relative = np.zeros_like(r_bins)
|
||||||
|
r_bins = np.insert(r_bins, 0, 0)
|
||||||
|
|
||||||
|
for i in range(len(r_bins) - 1):
|
||||||
|
mask = (r >= r_bins[i]) & (r < r_bins[i + 1])
|
||||||
|
m_shells[i] = np.sum(m[mask])
|
||||||
|
v_shells[i] = 4/3 * np.pi * (r_bins[i + 1]**3 - r_bins[i]**3)
|
||||||
|
if ret_error:
|
||||||
|
count = np.count_nonzero(mask)
|
||||||
|
if count > 0:
|
||||||
|
error_relative[i] = 1 / np.sqrt(count)
|
||||||
|
else:
|
||||||
|
error_relative[i] = 0
|
||||||
|
|
||||||
|
density = m_shells / v_shells
|
||||||
|
|
||||||
# add the first volume which should be wrt 0
|
|
||||||
volume = 4/3 * np.pi * (r_bins[1:]**3 - r_bins[:-1]**3)
|
|
||||||
volume = np.insert(volume, 0, 4/3 * np.pi * r_bins[0]**3)
|
|
||||||
density = r_bins / volume
|
|
||||||
if ret_error:
|
if ret_error:
|
||||||
return density, density / np.sqrt(r_bins)
|
return density, density * error_relative
|
||||||
else:
|
else:
|
||||||
return density
|
return density
|
||||||
|
|
||||||
@ -72,7 +87,10 @@ def mean_interparticle_distance(particles: np.ndarray):
|
|||||||
|
|
||||||
rho = n_half_mass / (4/3 * np.pi * r_half_mass**3)
|
rho = n_half_mass / (4/3 * np.pi * r_half_mass**3)
|
||||||
# the mean distance between particles is the inverse of the density
|
# the mean distance between particles is the inverse of the density
|
||||||
return (1 / rho)**(1/3)
|
|
||||||
|
epsilon = (1 / rho)**(1/3)
|
||||||
|
logger.info(f"Found mean interparticle distance: {epsilon}")
|
||||||
|
return epsilon
|
||||||
# TODO: check if this is correct
|
# TODO: check if this is correct
|
||||||
|
|
||||||
|
|
||||||
@ -104,21 +122,25 @@ def half_mass_radius(particles: np.ndarray):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def relaxation_timescale(particles: np.ndarray, G:float) -> float:
|
def total_energy(particles: np.ndarray):
|
||||||
"""
|
"""
|
||||||
Computes the relaxation timescale of a set of particles using the velocity at the half mass radius.
|
Computes the total energy of a set of particles.
|
||||||
Assumes that the particles array has the following columns: x, y, z ...
|
Assumes that the particles array has the following columns: x, y, z, vx, vy, vz, m.
|
||||||
|
Uses the approximation that the particles are in a central potential as computed in analytical.py
|
||||||
"""
|
"""
|
||||||
m_half = np.sum(particles[:, 3]) / 2 # enclosed mass at half mass radius
|
if particles.shape[1] != 7:
|
||||||
r_half = half_mass_radius(particles)
|
raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m")
|
||||||
n_half = np.sum(np.linalg.norm(particles[:, :3], axis=1) < r_half) # number of enclosed particles
|
|
||||||
v_c = np.sqrt(G * m_half / r_half)
|
|
||||||
|
|
||||||
# the crossing time for the half mass system is
|
# compute the kinetic energy
|
||||||
t_c = r_half / v_c
|
v = particles[:, 3:6]
|
||||||
logger.debug(f"Crossing time for half mass system: {t_c}")
|
m = particles[:, 6]
|
||||||
|
ke = 0.5 * np.sum(m * np.linalg.norm(v, axis=1)**2)
|
||||||
|
|
||||||
# the relaxation timescale is t_c * N/(10 * log(N))
|
# # compute the potential energy
|
||||||
t_rel = t_c * n_half / (10 * np.ln(n_half))
|
# forces = forces_basic.analytical_forces(particles)
|
||||||
|
# r = np.linalg.norm(particles[:, :3], axis=1)
|
||||||
return t_rel
|
# pe_particles = -forces[:, 0] * particles[:, 0] - forces[:, 1] * particles[:, 1] - forces[:, 2] * particles[:, 2]
|
||||||
|
# pe = np.sum(pe_particles)
|
||||||
|
# # TODO: i am pretty sure this is wrong
|
||||||
|
pe = 0
|
||||||
|
return ke + pe
|
||||||
|
41
nbody/utils/units.py
Normal file
41
nbody/utils/units.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import astropy.units as u
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
M_SCALE: int = None
|
||||||
|
R_SCALE: int = None
|
||||||
|
|
||||||
|
def seed_scales(r_scale: u.Quantity, m_scale: u.Quantity):
|
||||||
|
"""
|
||||||
|
Set the scales for the given simulation.
|
||||||
|
Parameters:
|
||||||
|
- r_scale: astropy.units.Quantity with units of length - the characteristic length scale of the simulation. Particle positions are expressed in units of this scale.
|
||||||
|
- m_scale: astropy.units.Quantity with units of mass - the characteristic mass scale of the simulation. Particle masses are expressed in units of this scale.
|
||||||
|
"""
|
||||||
|
global M_SCALE, R_SCALE
|
||||||
|
M_SCALE = m_scale
|
||||||
|
R_SCALE = r_scale
|
||||||
|
logger.info(f"Set scales: M_SCALE = {M_SCALE}, R_SCALE = {R_SCALE}")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_units(columns: np.array, quantity: str):
|
||||||
|
if quantity == "mass":
|
||||||
|
return columns * M_SCALE
|
||||||
|
elif quantity == "position":
|
||||||
|
return columns * R_SCALE
|
||||||
|
elif quantity == "velocity":
|
||||||
|
return columns * R_SCALE / u.s
|
||||||
|
elif quantity == "time":
|
||||||
|
return columns * u.s
|
||||||
|
elif quantity == "acceleration":
|
||||||
|
return columns * R_SCALE / u.s**2
|
||||||
|
elif quantity == "force":
|
||||||
|
return columns * M_SCALE * R_SCALE / u.s**2
|
||||||
|
elif quantity == "volume":
|
||||||
|
return columns * R_SCALE**3
|
||||||
|
elif quantity == "density":
|
||||||
|
return columns * M_SCALE / R_SCALE**3
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown quantity: {quantity}")
|
Loading…
x
Reference in New Issue
Block a user