53 lines
1.6 KiB
Python

from pathlib import Path
import numpy as np
import timeit
import logging
logger = logging.getLogger(__name__)
def cached_forces(cache_path: Path, particles: np.ndarray, force_function: callable, func_kwargs: dict):
    """
    Load forces and their benchmark time from a cache, computing both on a miss.

    The cache entry is identified by ``force_function.__name__``, the number of
    particles (``particles.shape[0]``) and ``kwargs_to_str(func_kwargs)``.
    NOTE: the particle *values* are not part of the key, so two different
    particle arrays of the same length share a cache entry.

    Args:
        cache_path: Directory holding the ``.npy`` cache files; created if missing.
        particles: Passed as the first positional argument to ``force_function``.
        force_function: Called as ``force_function(particles, **func_kwargs)``.
        func_kwargs: Extra keyword arguments forwarded to ``force_function``.

    Returns:
        ``(force, time)``. On a cache hit ``force`` is the array loaded with
        ``np.load``; on a miss it is the value returned by ``force_function``
        (which is also saved with ``np.save``). ``time`` is a ``float``: the
        total seconds reported by ``timeit.timeit`` for 10 calls of
        ``force_function`` (not a per-call average).
    """
    cache_path.mkdir(parents=True, exist_ok=True)
    n_particles = particles.shape[0]
    func_name = force_function.__name__
    # Both files share one stem so a force entry and its timing stay paired.
    stem = f"{func_name}__n_{n_particles}__kwargs_{kwargs_to_str(func_kwargs)}.npy"
    force_cache = cache_path / f"forces__{stem}"
    time_cache = cache_path / f"time__{stem}"

    if force_cache.exists() and time_cache.exists():
        force = np.load(force_cache)
        logger.info("Loaded forces from %s", force_cache)
        # np.save stores the scalar as a 0-d array; convert it back so the
        # cache-hit path returns the same type (float) as the miss path.
        time = float(np.load(time_cache))
        return force, time

    force = force_function(particles, **func_kwargs)
    np.save(force_cache, force)
    logger.info("Timing %s for %d particles", func_name, n_particles)
    time = timeit.timeit(lambda: force_function(particles, **func_kwargs), number=10)
    np.save(time_cache, time)
    return force, time
def kwargs_to_str(kwargs: dict) -> str:
    """
    Convert a dictionary of keyword arguments into a filename-friendly string.

    Each item becomes ``"{key}_{value}__"`` and the pieces are concatenated in
    the dict's iteration order, e.g. ``{"eps": 0.1, "fn": len}`` becomes
    ``"eps_0.100__fn_len__"``. An empty dict gives ``""``.

    Value formatting:
      * floats (including NumPy floating scalars) are formatted with 3 decimals;
      * callables are represented by their ``__name__`` (falling back to the
        name of their type), so the result is stable across runs;
      * anything else is formatted as in an f-string.
    """
    parts = []
    for key, value in kwargs.items():
        if isinstance(value, (float, np.floating)):
            text = f"{value:.3f}"
        elif callable(value):
            # str() of a function embeds its memory address, which changes on
            # every run and would make the resulting cache key never match.
            text = getattr(value, "__name__", type(value).__name__)
        else:
            text = f"{value}"
        parts.append(f"{key}_{text}__")
    return "".join(parts)