52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
import logging
import timeit
from pathlib import Path
from typing import Callable, Tuple

import numpy as np
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def cached_forces(
    cache_path: Path,
    particles: np.ndarray,
    force_function: Callable[..., np.ndarray],
    func_kwargs: dict,
) -> Tuple[np.ndarray, float]:
    """Return forces for ``particles``, loading them from an on-disk cache if present.

    The cache entry is identified by ``force_function.__name__``, the number of
    particles and ``kwargs_to_str(func_kwargs)``. On a cache miss the forces are
    computed once and saved; ``force_function`` is then benchmarked with
    ``timeit`` over 10 calls and that total is saved alongside the forces.

    Note: the particle *values* are not part of the cache key -- two different
    particle arrays with the same length share one cache entry.

    Args:
        cache_path: Directory holding the ``.npy`` cache files; created if missing.
        particles: Particle array; only ``particles.shape[0]`` enters the cache key.
        force_function: Called as ``force_function(particles, **func_kwargs)``.
        func_kwargs: Extra keyword arguments forwarded to ``force_function``.

    Returns:
        ``(forces, elapsed)`` where ``elapsed`` is the total time in seconds that
        ``timeit`` measured for 10 calls of ``force_function`` (freshly measured
        or loaded from the cache). ``elapsed`` is always a Python ``float``.
    """
    cache_path.mkdir(parents=True, exist_ok=True)

    func_name = force_function.__name__
    n_particles = particles.shape[0]
    # Shared suffix so the two cache files can never disagree on their key.
    stem = f"{func_name}__n_{n_particles}__kwargs_{kwargs_to_str(func_kwargs)}.npy"
    force_cache = cache_path / f"forces__{stem}"
    time_cache = cache_path / f"time__{stem}"

    # Only trust the cache when both files exist; a partial entry (forces saved
    # but timing never written) falls through to a full recompute.
    if force_cache.exists() and time_cache.exists():
        force = np.load(force_cache)
        logger.info("Loaded forces from %s", force_cache)
        # np.load yields a 0-d array for a saved scalar; convert so the return
        # type matches the freshly-timed branch below.
        elapsed = float(np.load(time_cache))
        return force, elapsed

    force = force_function(particles, **func_kwargs)
    np.save(force_cache, force)

    logger.info("Timing %s for %s particles", func_name, n_particles)
    # Total over 10 calls (not a per-call average), as stored in existing caches.
    elapsed = timeit.timeit(lambda: force_function(particles, **func_kwargs), number=10)
    np.save(time_cache, elapsed)
    return force, elapsed
|
|
|
|
|
|
|
|
def kwargs_to_str(kwargs: dict) -> str:
    """Encode keyword arguments as a filename-safe cache-key fragment.

    Each pair becomes ``"{key}_{value}__"``, concatenated in dict iteration
    order (so ``{"a": 1, "b": 2}`` and ``{"b": 2, "a": 1}`` give different
    strings). Values are rendered as follows:

    * plain ``float``: three decimals (``0.5 -> "0.500"``) when that is exact,
      otherwise ``repr`` so distinct values such as ``1e-4`` and ``2e-4`` do not
      both collapse to ``"0.000"``;
    * callables: their ``__name__``, falling back to their string form for
      callables without one (e.g. ``functools.partial`` objects);
    * anything else: its string form.

    Path separators (``/`` and ``\\``) in each rendered pair are replaced by
    ``-`` so the fragment cannot create or escape into subdirectories.

    Args:
        kwargs: Keyword arguments to encode.

    Returns:
        The concatenated fragment; ``""`` for an empty dict.
    """
    no_separators = str.maketrans({"/": "-", "\\": "-"})
    parts = []
    for key, value in kwargs.items():
        # Exact type check (not isinstance) on purpose: NumPy float scalars
        # keep their plain string form, so previously written cache names stay valid.
        if type(value) is float:
            text = f"{value:.3f}"
            # Three decimals can lose information (1e-4 -> "0.000"), which would
            # make different parameters share a cache entry; use repr then.
            if float(text) != value:
                text = repr(value)
        elif callable(value):
            text = getattr(value, "__name__", None) or format(value)
        else:
            text = format(value)
        parts.append(f"{key}_{text}__".translate(no_separators))
    return "".join(parts)
|