cleanup of t1, tried integration in t2, units
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1 +1,2 @@ | |||||||
| *.pyc | *.pyc | ||||||
|  | .cache/ | ||||||
							
								
								
									
										2
									
								
								Pipfile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Pipfile
									
									
									
									
									
								
							| @@ -9,7 +9,7 @@ ipython = "*" | |||||||
| jupyter = "*" | jupyter = "*" | ||||||
| matplotlib = "*" | matplotlib = "*" | ||||||
| scipy = "*" | scipy = "*" | ||||||
| spacepy = "*" | astropy = "*" | ||||||
|  |  | ||||||
| [dev-packages] | [dev-packages] | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										1644
									
								
								Pipfile.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1644
									
								
								Pipfile.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -8,7 +8,21 @@ | |||||||
| - [x] compare with the analytical expectation from Newtons 2nd law | - [x] compare with the analytical expectation from Newtons 2nd law | ||||||
| - [ ] compute the relaxation time | - [ ] compute the relaxation time | ||||||
|  |  | ||||||
| ### Task 2 | ### Task 2 (particle mesh) | ||||||
|  | - [ ] Choose reasonable units | ||||||
|  | - [ ] Implement force computation on mesh | ||||||
|  | - [ ] Find optimal mesh size | ||||||
|  | - [ ] Compare with direct nbody simulation | ||||||
|  | - [ ] Time integration for direct method AND mesh method | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ### Task 2 (tree code) | ||||||
|  | - [ ] Implement force computation with multipole expansion | ||||||
|  |     - [ ] Find optimal grouping criterion | ||||||
|  | - [ ] Compare with direct nbody simulation | ||||||
|  | - [ ] Time integration for direct method AND tree method | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -17,3 +31,7 @@ | |||||||
|  |  | ||||||
| ### Questions | ### Questions | ||||||
| - Procedure for each time step of a mesh simulation? Potential on mesh -> forces on particles -> update particle positions -> new mesh potential? or skip the creation of particles in each time step? | - Procedure for each time step of a mesh simulation? Potential on mesh -> forces on particles -> update particle positions -> new mesh potential? or skip the creation of particles in each time step? | ||||||
|  | - How to represent the time evolution of the system? | ||||||
|  |     - plot total energy vs time | ||||||
|  |     - plot particle positions? | ||||||
|  | - What is the parameter a of the Hernquist model? | ||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -1,9 +1,9 @@ | |||||||
| ## Import all functions in all the files in the current directory | ## Import all functions in all the files in the current directory | ||||||
| # Basic helpers for interacting with the data  | # Basic helpers for interacting with the data  | ||||||
| from .load import * | from .load import * | ||||||
| from .mesh import * |  | ||||||
| from .model import * | from .model import * | ||||||
| from .particles import * | from .particles import * | ||||||
|  | from .units import * | ||||||
|  |  | ||||||
| # Helpers for computing the forces | # Helpers for computing the forces | ||||||
| from .forces_basic import * | from .forces_basic import * | ||||||
|   | |||||||
| @@ -16,7 +16,7 @@ def n_body_forces(particles: np.ndarray, G: float, softening: float = 0): | |||||||
|  |  | ||||||
|     n = particles.shape[0] |     n = particles.shape[0] | ||||||
|     forces = np.zeros((n, 3)) |     forces = np.zeros((n, 3)) | ||||||
|     logger.debug(f"Computing forces for {n} particles using n^2 algorithm (using {softening=})") |     logger.debug(f"Computing forces for {n} particles using n^2 algorithm (using {softening=:.2g})") | ||||||
|  |  | ||||||
|     for i in range(n): |     for i in range(n): | ||||||
|         # the current particle is at x_current |         # the current particle is at x_current | ||||||
|   | |||||||
| @@ -56,19 +56,22 @@ def mesh_forces_v2(particles: np.ndarray, G: float, n_grid: int, mapping: callab | |||||||
|     if particles.shape[1] != 4: |     if particles.shape[1] != 4: | ||||||
|         raise ValueError("Particles array must have 4 columns: x, y, z, m") |         raise ValueError("Particles array must have 4 columns: x, y, z, m") | ||||||
|  |  | ||||||
|     logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh (using mapping={mapping.__name__})") |     logger.debug(f"Computing forces for {particles.shape[0]} particles using mesh [mapping={mapping.__name__}, {n_grid=}]") | ||||||
|  |  | ||||||
|     mesh, axis = to_mesh(particles, n_grid, mapping) |     mesh, axis = to_mesh(particles, n_grid, mapping) | ||||||
|     show_mesh_information(mesh, "Mesh") |     if logger.level >= logging.DEBUG: | ||||||
|  |         show_mesh_information(mesh, "Density mesh") | ||||||
|  |  | ||||||
|     spacing = axis[1] - axis[0] |     spacing = axis[1] - axis[0] | ||||||
|     logger.debug(f"Using mesh spacing: {spacing}") |     logger.debug(f"Using mesh spacing: {spacing}") | ||||||
|     phi = mesh_poisson_v2(mesh, G) |     phi = mesh_poisson_v2(mesh, G) | ||||||
|     logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}") |     logger.debug(f"Got phi with: {phi.shape}, {np.max(phi)}") | ||||||
|     phi_grad = np.stack(np.gradient(phi, spacing), axis=0) |     phi_grad = np.stack(np.gradient(phi, spacing), axis=0) | ||||||
|     show_mesh_information(phi, "Mesh potential") |     if logger.level >= logging.DEBUG: | ||||||
|     show_mesh_information(phi_grad[0], "Mesh potential grad x") |         show_mesh_information(phi, "Potential mesh") | ||||||
|  |         show_mesh_information(phi_grad[0], "Potential gradient") | ||||||
|     logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") |     logger.debug(f"Got phi_grad with: {phi_grad.shape}, {np.max(phi_grad)}") | ||||||
|  |      | ||||||
|     forces = np.zeros_like(particles[:, :3]) |     forces = np.zeros_like(particles[:, :3]) | ||||||
|     for i, p in enumerate(particles): |     for i, p in enumerate(particles): | ||||||
|         ijk = np.digitize(p, axis) - 1 |         ijk = np.digitize(p, axis) - 1 | ||||||
| @@ -93,7 +96,7 @@ def mesh_poisson_v2(mesh: np.ndarray, G: float) -> np.ndarray: | |||||||
|     return phi |     return phi | ||||||
|  |  | ||||||
|  |  | ||||||
| ## Helper functions for star mapping | #### Helper functions for star mapping | ||||||
| def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.ndarray, np.ndarray]: | def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.ndarray, np.ndarray]: | ||||||
|     """ |     """ | ||||||
|     Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid. |     Maps a list of particles to a of mesh of size n_grid x n_grid x n_grid. | ||||||
| @@ -108,7 +111,7 @@ def to_mesh(particles: np.ndarray, n_grid: int, mapping: callable) -> tuple[np.n | |||||||
|  |  | ||||||
|     for p in particles: |     for p in particles: | ||||||
|         m = p[-1] |         m = p[-1] | ||||||
|         if logger.level <= logging.DEBUG and  m <= 0: |         if logger.level >= logging.DEBUG and  m <= 0: | ||||||
|             logger.warning(f"Particle with negative mass: {p}") |             logger.warning(f"Particle with negative mass: {p}") | ||||||
|         # spread the star onto cells through the shape function, taking into account the mass |         # spread the star onto cells through the shape function, taking into account the mass | ||||||
|         ijks, weights = mapping(p, axis) |         ijks, weights = mapping(p, axis) | ||||||
| @@ -153,20 +156,20 @@ def particle_to_cells_cic(particle, axis, width): | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ## Helper functions for mesh plotting | #### Helper functions for mesh plotting | ||||||
| def show_mesh_information(mesh: np.ndarray, name: str): | def show_mesh_information(mesh: np.ndarray, name: str): | ||||||
|     print(f"Mesh information for {name}") |     logger.info(f"Mesh information for {name}") | ||||||
|     print(f"Total mapped mass: {np.sum(mesh):.0f}") |     logger.info(f"Total mapped mass: {np.sum(mesh):.0f}") | ||||||
|     print(f"Max cell value: {np.max(mesh)}") |     logger.info(f"Max cell value: {np.max(mesh)}") | ||||||
|     print(f"Min cell value: {np.min(mesh)}") |     logger.info(f"Min cell value: {np.min(mesh)}") | ||||||
|     print(f"Mean cell value: {np.mean(mesh)}") |     logger.info(f"Mean cell value: {np.mean(mesh)}") | ||||||
|     plot_3d(mesh, name) |     plot_3d(mesh, name) | ||||||
|     plot_2d(mesh, name) |     plot_2d(mesh, name) | ||||||
|  |  | ||||||
|  |  | ||||||
| def plot_3d(mesh: np.ndarray, name: str): | def plot_3d(mesh: np.ndarray, name: str): | ||||||
|     fig = plt.figure() |     fig = plt.figure() | ||||||
|     fig.suptitle(name) |     fig.suptitle(f"{name} - {mesh.shape}") | ||||||
|     ax = fig.add_subplot(111, projection='3d') |     ax = fig.add_subplot(111, projection='3d') | ||||||
|     ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis') |     ax.scatter(*np.where(mesh), c=mesh[np.where(mesh)], cmap='viridis') | ||||||
|     plt.show() |     plt.show() | ||||||
| @@ -174,7 +177,7 @@ def plot_3d(mesh: np.ndarray, name: str): | |||||||
|  |  | ||||||
| def plot_2d(mesh: np.ndarray, name: str): | def plot_2d(mesh: np.ndarray, name: str): | ||||||
|     fig = plt.figure() |     fig = plt.figure() | ||||||
|     fig.suptitle(name) |     fig.suptitle(f"{name} - {mesh.shape}") | ||||||
|     axs = fig.subplots(1, 3) |     axs = fig.subplots(1, 3) | ||||||
|     axs[0].imshow(np.sum(mesh, axis=0)) |     axs[0].imshow(np.sum(mesh, axis=0)) | ||||||
|     axs[0].set_title("Flattened in x") |     axs[0].set_title("Flattened in x") | ||||||
|   | |||||||
| @@ -16,54 +16,53 @@ def ode_setup(particles: np.ndarray, force_function: callable) -> tuple[np.ndarr | |||||||
|     if particles.shape[1] != 7: |     if particles.shape[1] != 7: | ||||||
|         raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m") |         raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m") | ||||||
|      |      | ||||||
|     n = particles.shape[0] |     # for scipy integrators we need to flatten array which contains 7 columns for now | ||||||
|     # for scipy integrators we need to flatten the n 3D positions and n 3D velocities |     # we don't really care how we reshape as long as we unflatten consistently afterwards | ||||||
|     y0 = np.zeros(6*n)  |     particles = particles.reshape(-1, copy=False, order='A') | ||||||
|     y0[:3*n] = particles[:, :3].flatten() |     # this is consistent with the unflattening in to_particles()! | ||||||
|     y0[3*n:] = particles[:, 3:6].flatten() |     logger.debug(f"Reshaped 7 columns into {particles.shape=}") | ||||||
|  |  | ||||||
|     # the masses don't change we can define them once |  | ||||||
|     masses = particles[:, 6] |  | ||||||
|     logger.debug(f"Reshaped {particles.shape} to y0 with {y0.shape} and masses with {masses.shape}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     def f(y, t): |     def f(y, t): | ||||||
|         """ |         """ | ||||||
|         Computes the right hand side of the ODE system. |         Computes the right hand side of the ODE system. | ||||||
|         The ODE system is linearized around the current positions and velocities. |         The ODE system is linearized around the current positions and velocities. | ||||||
|         """ |         """ | ||||||
|         n = y.size // 6 |         y = to_particles(y) | ||||||
|         logger.debug(f"y with shape {y.shape}") |         # now y has shape (n, 7), with columns x, y, z, vx, vy, vz, m | ||||||
|         # unsqueeze and unstack to extract the positions and velocities |  | ||||||
|         y = y.reshape((2*n, 3)) |  | ||||||
|         x = y[:n, ...] |  | ||||||
|         v = y[n:, ...] |  | ||||||
|         logger.debug(f"Unstacked y into x with shape {x.shape} and v with shape {v.shape}") |  | ||||||
|          |          | ||||||
|         # compute the forces |  | ||||||
|         x_with_m = np.zeros((n, 4)) |         forces = force_function(y[:, [0, 1, 2, -1]]) | ||||||
|         x_with_m[:, :3] = x |  | ||||||
|         x_with_m[:, 3] = masses |  | ||||||
|         forces = force_function(x_with_m) |  | ||||||
|          |          | ||||||
|         # compute the accelerations |         # compute the accelerations | ||||||
|  |         masses = y[:, -1] | ||||||
|         a = forces / masses[:, None] |         a = forces / masses[:, None] | ||||||
|         a.flatten() |  | ||||||
|         # the [:, None] is to force broadcasting in order to divide each row of forces by the corresponding mass |         # the [:, None] is to force broadcasting in order to divide each row of forces by the corresponding mass | ||||||
|  |         # a.flatten() | ||||||
|          |          | ||||||
|         # reshape into a 1D array |         # replace some values in y: | ||||||
|         return np.vstack((v, a)).flatten() |         # the position columns become the velocities | ||||||
|  |         # the velocity columns become the accelerations | ||||||
|  |         y[:, 0:3] = y[:, 3:6] | ||||||
|  |         y[:, 3:6] = a | ||||||
|  |         # the masses remain unchanged | ||||||
|  |  | ||||||
|     return y0, f |         # flatten the array again | ||||||
|  |         y = y.reshape(-1, copy=False, order='A') | ||||||
|  |         return y | ||||||
|  |  | ||||||
|  |     return particles, f | ||||||
|  |  | ||||||
|  |  | ||||||
| def to_particles(y: np.ndarray) -> np.ndarray: | def to_particles(y: np.ndarray) -> np.ndarray: | ||||||
|     """ |     """ | ||||||
|     Converts the 1D array y into a 2D array with the shape (n, 6) where n is the number of particles. |     Converts the 1D array y into a 2D array IN PLACE | ||||||
|     The columns are x, y, z, vx, vy, vz |     The new shape is (n, 7) where n is the number of particles. | ||||||
|  |     The columns are x, y, z, vx, vy, vz, m | ||||||
|     """ |     """ | ||||||
|     n = y.size // 6 |     if y.size % 7 != 0: | ||||||
|     y = y.reshape((2*n, 3)) |         raise ValueError("The array y should be inflatable to 7 columns") | ||||||
|     x = y[:n, ...] |      | ||||||
|     v = y[n:, ...] |     n = y.size // 7 | ||||||
|     return np.hstack((x, v)) |     y = y.reshape((n, 7), copy=False, order='F') | ||||||
|  |     logger.debug(f"Unflattened array into {y.shape=}") | ||||||
|  |     return y | ||||||
| @@ -3,7 +3,7 @@ import logging | |||||||
| logging.basicConfig( | logging.basicConfig( | ||||||
|     ## set logging level |     ## set logging level | ||||||
|     # level=logging.INFO, |     # level=logging.INFO, | ||||||
|     level=logging.DEBUG, |     level=logging.INFO, | ||||||
|     format='%(asctime)s - %(name)s - %(message)s', |     format='%(asctime)s - %(name)s - %(message)s', | ||||||
|     datefmt='%H:%M:%S' |     datefmt='%H:%M:%S' | ||||||
| ) | ) | ||||||
| @@ -13,3 +13,4 @@ logging.basicConfig( | |||||||
| logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING) | logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING) | ||||||
| logging.getLogger('matplotlib.ticker').setLevel(logging.WARNING) | logging.getLogger('matplotlib.ticker').setLevel(logging.WARNING) | ||||||
| logging.getLogger('matplotlib.pyplot').setLevel(logging.WARNING) | logging.getLogger('matplotlib.pyplot').setLevel(logging.WARNING) | ||||||
|  | logging.getLogger('matplotlib.colorbar').setLevel(logging.WARNING) | ||||||
| @@ -1,11 +1,13 @@ | |||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| M = 5 |  | ||||||
| a = 5 |  | ||||||
|  |  | ||||||
| def model_density_distribution(r_bins: np.ndarray): | def model_density_distribution(r_bins: np.ndarray, M: float = 5, a: float = 5) -> np.ndarray: | ||||||
|     """ |     """ | ||||||
|     Generate a density distribution for a spherical galaxy model, as per the Hernquist model. |     Generate a density distribution for a spherical galaxy model, as per the Hernquist model. | ||||||
|  |     Parameters: | ||||||
|  |     - r_bins: The radial bins to calculate the density distribution for. | ||||||
|  |     - M: The total mass of the galaxy. | ||||||
|  |     - a: The scale radius of the galaxy. | ||||||
|     See https://doi.org/10.1086%2F168845 for more information. |     See https://doi.org/10.1086%2F168845 for more information. | ||||||
|     """ |     """ | ||||||
|     rho = M / (2 * np.pi) * a / (r_bins * (r_bins + a)**3) |     rho = M / (2 * np.pi) * a / (r_bins * (r_bins + a)**3) | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| import numpy as np | import numpy as np | ||||||
| import logging | import logging | ||||||
|  | from . import forces_basic | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -7,20 +8,34 @@ def density_distribution(r_bins: np.ndarray, particles: np.ndarray, ret_error: b | |||||||
|     """ |     """ | ||||||
|     Computes the radial density distribution of a set of particles. |     Computes the radial density distribution of a set of particles. | ||||||
|     Assumes that the particles array has the following columns: x, y, z, m. |     Assumes that the particles array has the following columns: x, y, z, m. | ||||||
|  |     If ret_error is True, it will return the absolute error of the density. | ||||||
|     """ |     """ | ||||||
|     if particles.shape[1] != 4: |     if particles.shape[1] != 4: | ||||||
|         raise ValueError("Particles array must have 4 columns: x, y, z, m") |         raise ValueError("Particles array must have 4 columns: x, y, z, m") | ||||||
|  |  | ||||||
|     m = particles[:, 3] |     m = particles[:, 3] | ||||||
|     r = np.linalg.norm(particles[:, :3], axis=1) |     r = np.linalg.norm(particles[:, :3], axis=1) | ||||||
|     density = [np.sum(m[(r >= r_bins[i]) & (r < r_bins[i + 1])]) for i in range(len(r_bins) - 1)] |  | ||||||
|      |      | ||||||
|     # add the first volume which  should be wrt 0 |     m_shells = np.zeros_like(r_bins) | ||||||
|     volume = 4/3 * np.pi * (r_bins[1:]**3 - r_bins[:-1]**3) |     v_shells = np.zeros_like(r_bins) | ||||||
|     volume = np.insert(volume, 0, 4/3 * np.pi * r_bins[0]**3) |     error_relative = np.zeros_like(r_bins) | ||||||
|     density = r_bins / volume |     r_bins = np.insert(r_bins, 0, 0) | ||||||
|  |  | ||||||
|  |     for i in range(len(r_bins) - 1): | ||||||
|  |         mask = (r >= r_bins[i]) & (r < r_bins[i + 1]) | ||||||
|  |         m_shells[i] = np.sum(m[mask]) | ||||||
|  |         v_shells[i] = 4/3 * np.pi * (r_bins[i + 1]**3 - r_bins[i]**3) | ||||||
|  |         if ret_error: | ||||||
|  |             count = np.count_nonzero(mask) | ||||||
|  |             if count > 0: | ||||||
|  |                 error_relative[i] = 1 / np.sqrt(count) | ||||||
|  |             else: | ||||||
|  |                 error_relative[i] = 0 | ||||||
|  |  | ||||||
|  |     density = m_shells / v_shells | ||||||
|  |  | ||||||
|     if ret_error: |     if ret_error: | ||||||
|         return density, density / np.sqrt(r_bins) |         return density, density * error_relative | ||||||
|     else: |     else: | ||||||
|         return density |         return density | ||||||
|  |  | ||||||
| @@ -72,7 +87,10 @@ def mean_interparticle_distance(particles: np.ndarray): | |||||||
|  |  | ||||||
|     rho = n_half_mass / (4/3 * np.pi * r_half_mass**3) |     rho = n_half_mass / (4/3 * np.pi * r_half_mass**3) | ||||||
|     # the mean distance between particles is the inverse of the density |     # the mean distance between particles is the inverse of the density | ||||||
|     return (1 / rho)**(1/3) |  | ||||||
|  |     epsilon = (1 / rho)**(1/3) | ||||||
|  |     logger.info(f"Found mean interparticle distance: {epsilon}") | ||||||
|  |     return epsilon | ||||||
|     # TODO: check if this is correct |     # TODO: check if this is correct | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -104,21 +122,25 @@ def half_mass_radius(particles: np.ndarray): | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def relaxation_timescale(particles: np.ndarray, G:float) -> float: | def total_energy(particles: np.ndarray): | ||||||
|     """ |     """ | ||||||
|     Computes the relaxation timescale of a set of particles using the velocity at the half mass radius. |     Computes the total energy of a set of particles. | ||||||
|     Assumes that the particles array has the following columns: x, y, z ... |     Assumes that the particles array has the following columns: x, y, z, vx, vy, vz, m. | ||||||
|  |     Uses the approximation that the particles are in a central potential as computed in analytical.py | ||||||
|     """ |     """ | ||||||
|     m_half = np.sum(particles[:, 3]) / 2 # enclosed mass at half mass radius |     if particles.shape[1] != 7: | ||||||
|     r_half = half_mass_radius(particles) |         raise ValueError("Particles array must have 7 columns: x, y, z, vx, vy, vz, m") | ||||||
|     n_half = np.sum(np.linalg.norm(particles[:, :3], axis=1) < r_half) # number of enclosed particles |  | ||||||
|     v_c = np.sqrt(G * m_half / r_half) |  | ||||||
|  |  | ||||||
|     # the crossing time for the half mass system is |     # compute the kinetic energy | ||||||
|     t_c = r_half / v_c |     v = particles[:, 3:6] | ||||||
|     logger.debug(f"Crossing time for half mass system: {t_c}") |     m = particles[:, 6] | ||||||
|  |     ke = 0.5 * np.sum(m * np.linalg.norm(v, axis=1)**2) | ||||||
|  |  | ||||||
|     # the relaxation timescale is t_c * N/(10 * log(N)) |     # # compute the potential energy | ||||||
|     t_rel = t_c * n_half / (10 * np.ln(n_half)) |     # forces = forces_basic.analytical_forces(particles) | ||||||
|  |     # r = np.linalg.norm(particles[:, :3], axis=1) | ||||||
|     return t_rel |     # pe_particles = -forces[:, 0] * particles[:, 0] - forces[:, 1] * particles[:, 1] - forces[:, 2] * particles[:, 2] | ||||||
|  |     # pe = np.sum(pe_particles) | ||||||
|  |     # # TODO: i am pretty sure this is wrong | ||||||
|  |     pe = 0 | ||||||
|  |     return ke + pe | ||||||
|   | |||||||
							
								
								
									
										41
									
								
								nbody/utils/units.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								nbody/utils/units.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | |||||||
|  | import astropy.units as u | ||||||
|  | import numpy as np | ||||||
|  | import logging | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | M_SCALE: int = None | ||||||
|  | R_SCALE: int = None | ||||||
|  |  | ||||||
|  | def seed_scales(r_scale: u.Quantity, m_scale: u.Quantity): | ||||||
|  |     """ | ||||||
|  |     Set the scales for the given simulation. | ||||||
|  |     Parameters: | ||||||
|  |     - r_scale: astropy.units.Quantity with units of length - the characteristic length scale of the simulation. Particle positions are expressed in units of this scale. | ||||||
|  |     - m_scale: astropy.units.Quantity with units of mass - the characteristic mass scale of the simulation. Particle masses are expressed in units of this scale. | ||||||
|  |     """ | ||||||
|  |     global M_SCALE, R_SCALE | ||||||
|  |     M_SCALE = m_scale | ||||||
|  |     R_SCALE = r_scale | ||||||
|  |     logger.info(f"Set scales: M_SCALE = {M_SCALE}, R_SCALE = {R_SCALE}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def apply_units(columns: np.array, quantity: str): | ||||||
|  |     if quantity == "mass": | ||||||
|  |         return columns * M_SCALE | ||||||
|  |     elif quantity == "position": | ||||||
|  |         return columns * R_SCALE | ||||||
|  |     elif quantity == "velocity": | ||||||
|  |         return columns * R_SCALE / u.s | ||||||
|  |     elif quantity == "time": | ||||||
|  |         return columns * u.s | ||||||
|  |     elif quantity == "acceleration": | ||||||
|  |         return columns * R_SCALE / u.s**2 | ||||||
|  |     elif quantity == "force": | ||||||
|  |         return columns * M_SCALE * R_SCALE / u.s**2 | ||||||
|  |     elif quantity == "volume": | ||||||
|  |         return columns * R_SCALE**3 | ||||||
|  |     elif quantity == "density": | ||||||
|  |         return columns * M_SCALE / R_SCALE**3 | ||||||
|  |     else: | ||||||
|  |         raise ValueError(f"Unknown quantity: {quantity}") | ||||||
		Reference in New Issue
	
	Block a user