34 lines
1.2 KiB
Python

import logging
import timeit
from pathlib import Path
from typing import Callable

import numpy as np
logger = logging.getLogger(__name__)


def cached_forces(cache_path: Path, particles: np.ndarray, force_function: Callable, func_kwargs: dict):
    """
    Load cached forces and timing for ``force_function``, computing them on a miss.

    Two ``.npy`` files are kept in ``cache_path``: one for the forces and one
    for the timing. Both are keyed by the function's ``__name__``, the number
    of particles and the ``func_kwargs`` items. A cache hit requires both
    files to exist. Otherwise the forces are computed once and saved, and the
    function is then timed over 10 further calls.

    Args:
        cache_path: Directory holding the cache files; created if missing.
        particles: Particle array. Only ``particles.shape[0]`` enters the cache key.
        force_function: Called as ``force_function(particles, **func_kwargs)``.
            Its result is stored with ``np.save``.
        func_kwargs: Keyword arguments forwarded to ``force_function``. Each
            item is embedded in the file names as ``key_value``, joined by ``_``.

    Returns:
        ``(force, time)``. ``force`` is the computed or loaded forces. ``time``
        is a float: the total seconds reported by ``timeit.timeit`` for 10
        calls (a sum over the calls, not a per-call average).

    NOTE(review): the cache key ignores the particle *values*. A later call
    with different particles of the same count returns the previously
    cached forces. Confirm this is intended.
    """
    cache_path.mkdir(parents=True, exist_ok=True)
    n_particles = particles.shape[0]
    func_name = force_function.__name__
    kwargs_str = "_".join(f"{k}_{v}" for k, v in func_kwargs.items())
    # Shared file-name stem. The resulting names match the historical
    # layout, so existing caches remain valid.
    stem = f"{func_name}__n_{n_particles}__kwargs_{kwargs_str}.npy"
    force_cache = cache_path / f"forces__{stem}"
    time_cache = cache_path / f"time__{stem}"

    if force_cache.exists() and time_cache.exists():
        force = np.load(force_cache)
        logger.info("Loaded forces from %s", force_cache)
        # np.save stored the timing as a 0-d array. Unwrap it so both
        # branches return a plain float.
        elapsed = float(np.load(time_cache))
    else:
        force = force_function(particles, **func_kwargs)
        np.save(force_cache, force)
        logger.info("Timing %s for %s particles", func_name, n_particles)
        # Total time over 10 calls, in addition to the call above.
        elapsed = timeit.timeit(lambda: force_function(particles, **func_kwargs), number=10)
        np.save(time_cache, elapsed)
    return force, elapsed