{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# automatically reflect changes in imported modules\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import astropy.units as u\n",
    "\n",
    "import utils\n",
    "import utils.logging_config\n",
    "utils.logging_config.set_log_level(\"info\")\n",
    "import logging\n",
    "logger = logging.getLogger(\"task2 (mesh)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_ROOT = Path('data')\n",
    "# DATA_NAME = 'data0.txt'\n",
    "DATA_NAME = 'data1.txt'\n",
    "# DATA_NAME = 'data0_noise.txt'\n",
    "DATA_NAME = 'data1_noise.txt'\n",
    "NBINS = 30\n",
    "CACHE_ROOT = Path('.cache')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "points, columns = utils.load_data(DATA_ROOT / DATA_NAME)\n",
    "logger.debug(f\"Fetched {points.shape[0]} points, columns: {columns}\")\n",
    "# points = points[1:100, ...]\n",
    "# points = points[::5]\n",
    "# TODO remove\n",
    "# reorder the columns to match the expected order (x, y, z, mass)\n",
    "particles = points[:, [2, 3, 4, 1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.particles_plot_3d(particles)\n",
    "# Note: colormap corresponds to the mass of the particles\n",
    "utils.particles_plot_2d(particles)\n",
    "# Note: colormap corresponds to the order of the particles in the array\n",
    "\n",
    "## Also consider the velocity distribution\n",
    "velocities = points[:, [5, 6, 7]]\n",
    "r = np.linalg.norm(particles[..., :3], axis=-1)\n",
    "v = np.linalg.norm(velocities, axis=-1)\n",
    "plt.figure()\n",
    "plt.plot(r, v, 'o')\n",
    "plt.xlabel('r')\n",
    "plt.ylabel('v')\n",
    "plt.show()\n",
    "\n",
    "## Check the velocity direction\n",
    "radial_velocities = np.zeros_like(v)\n",
    "circular_velocities = np.zeros_like(v)\n",
    "for i in range(particles.shape[0]):\n",
    "    if r[i] > 0:\n",
    "        radial_velocities[i] = np.abs(np.dot(velocities[i], particles[i, :3]) / r[i])\n",
    "        circular_velocities[i] = np.linalg.norm(np.cross(particles[i, :3], velocities[i])) / r[i]\n",
    "    else:\n",
    "        radial_velocities[i] = 0\n",
    "        circular_velocities[i] = 0\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(r, radial_velocities, 'o', label=\"Radial velocity\")\n",
    "plt.plot(r, circular_velocities, 'o', label=\"Circular velocity\")\n",
    "plt.xlabel('r')\n",
    "plt.ylabel('v')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### System characteristics\n",
    "- central black hole with mass $\\sim 10 \\%$\n",
    "- particles orbit circularly in the equatorial plane\n"
   ]
  },
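  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Quick sanity check of the claims above (a rough sketch; it assumes the\n",
    "## black hole is the first row of the array, as used later in this notebook)\n",
    "bh_mass_fraction = particles[0, 3] / particles[:, 3].sum()\n",
    "logger.info(f\"Black hole mass fraction: {bh_mass_fraction:.2%}\")\n",
    "\n",
    "# flatness of the disk: maximum |z| compared to the maximum radius\n",
    "logger.info(f\"max |z| / max r: {np.abs(particles[1:, 2]).max() / r[1:].max():.2g}\")\n",
    "\n",
    "# circularity: radial vs. circular velocity components (computed above)\n",
    "mask = circular_velocities > 0\n",
    "logger.info(f\"median |v_r| / |v_circ|: {np.median(radial_velocities[mask] / circular_velocities[mask]):.2g}\")"
   ]
  },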
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Choice of units\n",
    "Recap of the particle properties:\n",
    "- $\\sim 10^4$ particles\n",
    "- around 1 black hole (10% of the mass)\n",
    "\n",
    "$\\implies$ ???"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set G = 1\n",
    "G = 1\n",
    "\n",
    "# from the particle number we can estimate the total (stellar) mass, excluding the BH\n",
    "M_TOT = 1e4 * u.M_sun\n",
    "# the radius aound the black hole follows from ??? # TODO\n",
    "R_TOT = 1 * u.pc\n",
    "\n",
    "# Rescale the units of the particles - considering only the orbiting stars\n",
    "M_particles = particles[:,3].sum() - 1\n",
    "R_particles = np.max(np.linalg.norm(particles[:, :3], axis=1))\n",
    "\n",
    "logger.info(f\"Considering a globular cluster - total mass of stars: {M_particles}, maximum radius of particles: {R_particles}\")\n",
    "m_scale = M_TOT / M_particles\n",
    "r_scale = R_TOT / R_particles\n",
    "utils.seed_scales(r_scale, m_scale)\n",
    "logger.info(f\"Black hole mass: {utils.apply_units(particles[0, -1], \"mass\"):.2g}\")\n"
   ]
  },
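  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Implied time unit of the G = 1 system (a small sketch based on dimensional\n",
    "## analysis, T = sqrt(R^3 / (G M)); whether utils adopts exactly this convention\n",
    "## internally is an assumption)\n",
    "from astropy.constants import G as G_si  # physical G, distinct from the code unit G = 1\n",
    "\n",
    "t_scale = np.sqrt(R_TOT**3 / (G_si * M_TOT)).to(u.Myr)\n",
    "logger.info(f\"Implied time unit: {t_scale:.3g}\")"
   ]
  },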
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Direct N body force computation\n",
    "epsilon = utils.mean_interparticle_distance(particles)\n",
    "\n",
    "epsilon_range = np.logspace(-2, 2, 5)\n",
    "epsilon_range = np.logspace(-1, 1, 3)\n",
    "n_squared_forces = []\n",
    "\n",
    "SAVE_FORCES = False\n",
    "\n",
    "for e in epsilon_range:\n",
    "    n_particles = particles.shape[0]\n",
    "    cache_file = CACHE_ROOT / f\"n_squared_forces__n_{n_particles}__softening_multiplier_{e:.0f}.npy\"\n",
    "    try:\n",
    "        f = np.load(cache_file)\n",
    "        logger.info(f\"Loaded forces from {cache_file}\")\n",
    "    except FileNotFoundError:\n",
    "        f = utils.n_body_forces(particles, G, e * epsilon)\n",
    "        if SAVE_FORCES:\n",
    "            np.save(cache_file, f)\n",
    "            logger.debug(f\"Saved forces to {cache_file}\")\n",
    "    n_squared_forces.append(f)\n",
    "\n",
    "### Mesh based force computation\n",
    "mesh_size_range = [10, 20, 50, 100, 150, 200]\n",
    "mesh_size_range = [20, 75, 50, 100]\n",
    "# TODO add uneven numbers\n",
    "mapping = utils.particle_to_cells_nn\n",
    "\n",
    "mesh_forces = []\n",
    "for mesh_size in mesh_size_range:\n",
    "    cache_file = CACHE_ROOT / f\"mesh_forces__n_{n_particles}__mesh_size_{mesh_size}__mapping_{mapping.__name__}.npy\"\n",
    "    try:\n",
    "        f = np.load(cache_file)\n",
    "        logger.info(f\"Loaded forces from {cache_file}\")\n",
    "    except FileNotFoundError:\n",
    "        f = utils.mesh_forces_v2(particles, G, mesh_size, mapping)\n",
    "        if SAVE_FORCES:\n",
    "            np.save(cache_file, f)\n",
    "            logger.debug(f\"Saved forces to {cache_file}\")\n",
    "    mesh_forces.append(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Compare the mesh computation with the direct summation\n",
    "r = np.linalg.norm(particles[:,:3], axis=1)\n",
    "\n",
    "plt.figure()\n",
    "plt.title('Radial force dependence')\n",
    "plt.xscale('log')\n",
    "plt.yscale('log')\n",
    "plt.xlabel('$r$')\n",
    "plt.ylabel('$|F(r)|$')\n",
    "\n",
    "# many of the particles have the same distance from the origin, so we skip some of them\n",
    "SKIP_N = 20\n",
    "\n",
    "for f, e in zip(n_squared_forces, epsilon_range):\n",
    "    # remove the black hole:\n",
    "    plt.plot(r[1::SKIP_N], np.linalg.norm(f, axis=1)[1::SKIP_N], 'o', label=f\"$N^2$ - {e:.1g} * $\\\\epsilon$\", alpha=0.3)\n",
    "for f, s in zip(mesh_forces, mesh_size_range):\n",
    "    # remove the black hole:\n",
    "    plt.plot(r[1::SKIP_N], np.linalg.norm(f, axis=1)[1::SKIP_N], 'x', label=f\"Mesh - N={s}\")\n",
    "\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# TODO: compare computation time\n",
    "\"\"\"\n",
    "plt.figure()\n",
    "plt.xscale('log')\n",
    "plt.yscale('log')\n",
    "f = n_squared_forces[1]\n",
    "f_val = np.linalg.norm(f, axis=1)\n",
    "logger.debug(f\"F n square - {f_val.max()=:.2g} - {f_val.min()=:.2g}\")\n",
    "plt.plot(r, f_val, 'o', label='N^2')\n",
    "f = mesh_forces[1]\n",
    "f_val = np.linalg.norm(f, axis=1)\n",
    "logger.debug(f\"F mesh - {f_val.max()=:.2g} - {f_val.min()=:.2g}\")\n",
    "logger.debug(f\"Mesh size: {mesh_size_range[1]}\")\n",
    "plt.plot(r, f_val, 'x', label='Mesh')\n",
    "plt.ylim([5e-4, 1e2])\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\"\"\""
   ]
  },
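  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Rough quantification of the agreement seen above (a sketch; it uses the\n",
    "## 1 * epsilon N^2 run, i.e. index 1 of epsilon_range, as the baseline)\n",
    "baseline = np.linalg.norm(n_squared_forces[1], axis=1)\n",
    "for f, s in zip(mesh_forces, mesh_size_range):\n",
    "    rel_err = np.abs(np.linalg.norm(f, axis=1) - baseline) / (baseline + 1e-12)\n",
    "    # skip the black hole (index 0), consistent with the plot above\n",
    "    logger.info(f\"Mesh N={s}: median relative |F| error {np.median(rel_err[1:]):.2g}\")"
   ]
  },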
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Discussion\n",
    "- Using the baseline of $N^2 + 1 \\varepsilon$ softening we can see that already a 20 x 20 x 20 grid provides good accuracy but the mapping breaks down at small distances (dip)\n",
    "- Larger grids are more stable, especially at small distances => 50 x 50 x 50 already seems to be a good choice\n",
    "- very large grids show overdiscretization => noisy data even for the non-noisy particle distributions\n"
   ]
  },
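  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Rough timing comparison (relates to the computation-time TODO above; a sketch\n",
    "## only - timings are taken on a subset of the particles to keep the direct\n",
    "## summation affordable)\n",
    "import time\n",
    "\n",
    "subset = particles[:2000]\n",
    "\n",
    "t0 = time.perf_counter()\n",
    "utils.n_body_forces(subset, G, epsilon)\n",
    "t1 = time.perf_counter()\n",
    "utils.mesh_forces_v2(subset, G, 50, utils.particle_to_cells_nn)\n",
    "t2 = time.perf_counter()\n",
    "\n",
    "logger.info(f\"Direct N^2 ({subset.shape[0]} particles): {t1 - t0:.2f} s\")\n",
    "logger.info(f\"Mesh N=50 ({subset.shape[0]} particles): {t2 - t1:.2f} s\")"
   ]
  },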
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Time integration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.integrate as spi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the particles in the format [x, y, z, vx, vy, vz, mass]\n",
    "p0 = points[:, [2, 3, 4, 5, 6, 7, 1]]\n",
    "\n",
    "logger.info(f\"Considering {p0.shape[0]} particles\")\n",
    "logger.info(f\"Total mass: {np.sum(p0[:,6])}\")\n",
    "\n",
    "if logger.level <= logging.DEBUG:\n",
    "    # assert that the ODE reshaping is consistent\n",
    "    p0_ref = p0.copy()\n",
    "    y0, _ = utils.ode_setup(p0, None)\n",
    "    logger.debug(y0[0:7])\n",
    "    p0_reconstructed = utils.to_particles(y0)\n",
    "    logger.debug(f\"{p0_ref[0]} -> {p0_reconstructed[0]}\")\n",
    "    logger.debug(f\"{p0_ref[1]} -> {p0_reconstructed[1]}\")\n",
    "\n",
    "    assert np.allclose(p0_ref, p0_reconstructed)\n",
    "    logger.debug(\"Consistency check passed\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def integrate(method: str, force_function: callable, p0: np.ndarray, t_range: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Integrate the gravitational movement of the particles, using the specified method\n",
    "    - method: the integration method to use (\"scipy\" or \"rk4\")\n",
    "    - force_function: the function that computes the forces acting on the particles\n",
    "    - p0: the initial conditions of the particles (n, 7) array, unflattened\n",
    "    - t_range: the time range to integrate over\n",
    "    Returns: the integrated positions and velocities of the particles in a 'flattened' array (time_steps, nx7)\n",
    "    \"\"\"\n",
    "    y0, y_prime = utils.ode_setup(p0, force_function)\n",
    "    \n",
    "    if method == \"scipy\":\n",
    "        sol = spi.odeint(y_prime, y0, t_range, rtol=1e-2)\n",
    "    elif method == \"rk4\":\n",
    "        sol = np.zeros((t_range.shape[0], y0.shape[0]))\n",
    "        sol[0] = y0\n",
    "        dt = t_range[1] - t_range[0]\n",
    "        for i in range(1, t_range.shape[0]):\n",
    "            t = t_range[i]\n",
    "            sol[i,...] = utils.runge_kutta_4(sol[i-1], t, y_prime, dt)\n",
    "\n",
    "\n",
    "    logger.info(f\"Integration done, shape: {sol.shape}\")\n",
    "    return sol\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Determine the integration timesteps\n",
    "# let's first compute the crossing time\n",
    "v = np.linalg.norm(particles[:, 3:6], axis=1)\n",
    "v_mean = np.mean(v)\n",
    "# a timestep should result in a small displacement, wrt. to the mean interparticle distance\n",
    "r_inter = utils.mean_interparticle_distance(particles)\n",
    "\n",
    "dt = r_inter / v_mean * 1e-3\n",
    "logger.info(f\"Mean velocity: {v_mean}, timestep: {dt}\")\n",
    "\n",
    "if np.isnan(dt):\n",
    "    raise ValueError(\"Invalid timestep\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Integration setup - use the n_squared forces for a few timesteps only, to see if the orbits are stable\n",
    "t_orbit = 2 * np.pi * r_inter / v_mean\n",
    "n_steps = int(t_orbit / dt * 5)\n",
    "n_steps = 10\n",
    "t_range = np.arange(0, n_steps*dt, dt)\n",
    "assert t_range.shape[0] == n_steps\n",
    "logger.info(f\"Integration range: {t_range[0]} -> {t_range[-1]}, n_steps: {n_steps}\")\n",
    "\n",
    "# The force function can be interchanged\n",
    "epsilon = utils.mean_interparticle_distance(particles)\n",
    "# epsilon = 0.01\n",
    "\n",
    "force_function = lambda x: utils.n_body_forces(x, G, epsilon)\n",
    "# force_function = lambda x: 0\n",
    "# force_function = lambda x: utils.n_body_forces_basic(x, G, epsilon)\n",
    "# force_function = lambda x: utils.analytical_forces(x)\n",
    "# force_function = lambda x: utils.mesh_forces_v2(x, G, 75, utils.particle_to_cells_nn)\n",
    "\n",
    "sol = integrate(\"rk4\", force_function, p0, t_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Plot a fixed number of states\n",
    "SHOW_N_STATES = 10 # should be even\n",
    "# skip some particles to make the plot more readable\n",
    "SHOW_NTH_PARTICLE = 5\n",
    "\n",
    "particles_in_time = utils.to_particles_3d(sol)\n",
    "\n",
    "\"\"\"\n",
    "## Show the particles in 3D\n",
    "fig, axs = plt.subplots(2, SHOW_N_STATES//2, subplot_kw={'projection': '3d'})\n",
    "\n",
    "for i, ax in enumerate(axs.flat):\n",
    "    nth = int(particles_in_time.shape[0] / SHOW_N_STATES) * i\n",
    "    p = particles_in_time[nth][::SHOW_NTH_PARTICLE]\n",
    "    ax.scatter(p[:,0], p[:,1], p[:,2], cmap='viridis', c=range(p.shape[0]))\n",
    "    ax.set_title(f\"t={t_range[nth]:.2g} (step {nth})\")\n",
    "\n",
    "fig.set_size_inches(18, 12)\n",
    "plt.show()\n",
    "\n",
    "## Show the 2D orbits of selected particles\n",
    "fig, axs = plt.subplots(2, SHOW_N_STATES//2, sharex=True, sharey=True)\n",
    "\n",
    "for i, ax in enumerate(axs.flat):\n",
    "    nth = int(particles_in_time.shape[0] / SHOW_N_STATES) * i\n",
    "    x = particles_in_time[:,i,0]\n",
    "    y = particles_in_time[:,i,1]\n",
    "    ax.scatter(x, y, c=range(t_range.size))\n",
    "    ax.set_title(f\"particle {nth}\")\n",
    "\n",
    "    ax.set_xlabel('x')\n",
    "    ax.set_ylabel('y')\n",
    "\n",
    "# Share x and y axis\n",
    "for ax in axs.flat:\n",
    "    ax.label_outer()\n",
    "\n",
    "fig.set_size_inches(18, 12)\n",
    "plt.show()\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "## Show the evolution of the density distrubtion\n",
    "for i in range(particles_in_time.shape[0]):\n",
    "    p = particles_in_time[i]\n",
    "    utils.particles_plot_2d(p, title=f\"t={t_range[i]:.2g}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Plot some key quantities of the system as a whole\n",
    "# sol has the shape (n_steps, n_particles*6) where the first 3*n are the positions and the last 3*n are the velocities\n",
    "\n",
    "# kinetic energy\n",
    "energies = np.zeros(n_steps)\n",
    "for i in range(n_steps):\n",
    "    p = particles_in_time[i]\n",
    "    ke_per_particle = 0.5 * p[:, 6] * np.linalg.norm(p[:,3:6], axis=1)**2\n",
    "    # logger.debug(f\"KE: {ke_per_particle.shape}\")\n",
    "    k_e = np.sum(ke_per_particle)\n",
    "    energies[i] = k_e\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(t_range, energies)\n",
    "plt.title('Kinetic energy')\n",
    "plt.xlabel('Integration time')\n",
    "plt.ylabel('Energy')\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# radial extrema of the particles - disk surface\n",
    "r_mins = np.zeros(n_steps)\n",
    "r_maxs = np.zeros(n_steps)\n",
    "for i in range(n_steps):\n",
    "    p = particles_in_time[i][1:,...] # remove the black hole\n",
    "    r = np.linalg.norm(p[:,:3], axis=1)\n",
    "    r_mins[i] = np.min(r)\n",
    "    r_maxs[i] = np.max(r)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(t_range, r_mins, label='r_min')\n",
    "plt.plot(t_range, r_maxs, label='r_max')\n",
    "plt.title('Radial extrema')\n",
    "plt.xlabel('Integration time')\n",
    "plt.ylabel('r')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
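  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Another global diagnostic (a sketch): the total angular momentum about the\n",
    "## z-axis, which should be approximately conserved for this disk-like system\n",
    "L_z = np.zeros(n_steps)\n",
    "for i in range(n_steps):\n",
    "    p = particles_in_time[i]\n",
    "    # L_z = sum_i m_i * (x_i * vy_i - y_i * vx_i)\n",
    "    L_z[i] = np.sum(p[:, 6] * (p[:, 0] * p[:, 4] - p[:, 1] * p[:, 3]))\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(t_range, L_z)\n",
    "plt.title('Total angular momentum ($L_z$)')\n",
    "plt.xlabel('Integration time')\n",
    "plt.ylabel('$L_z$')\n",
    "plt.show()"
   ]
  },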
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Full PM solver\n",
    "We now have all the tools to implement the full PM solver:\n",
    "- force computation using mesh\n",
    "- integrator with RK4\n",
    "- estimate for good timesteps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Integration setup - use the n_squared forces for a few timesteps only, to see if the orbits are stable\n",
    "t_orbit = 2 * np.pi * r_inter / v_mean\n",
    "n_steps = int(t_orbit / dt * 30)\n",
    "t_range = np.arange(0, n_steps*dt, dt)\n",
    "logger.info(f\"Integration range: {t_range[0]} -> {t_range[-1]}, n_steps: {n_steps}\")\n",
    "\n",
    "\n",
    "mesh_size = 50 # as per the previous discussion\n",
    "force_function = lambda x: utils.mesh_forces_v2(x, G, mesh_size, utils.particle_to_cells_nn)\n",
    "\n",
    "sol = integrate(\"rk4\", force_function, p0, t_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Show some results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "projects-X-9bmgL6",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}