# src/iohmm_evac/bootstrap/aggregate.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Aggregate per-cell sweep rows into shift-wise quantile bands."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from iohmm_evac.bootstrap.shift_sweep import ShiftSweepResult
from iohmm_evac.types import FloatArray

# Public API of this module: the metric-name constant, the result container,
# and the two aggregation entry points defined below.
__all__ = [
    "BAND_METRICS",
    "BandResult",
    "compute_bands",
    "metric_matrix",
]


# Each name is read off a sweep row with ``getattr`` in ``metric_matrix``, and
# ``compute_bands`` iterates this tuple in order when filling its ``bands``
# mapping, so the order here is the key order of the resulting dict.
BAND_METRICS: tuple[str, ...] = (
    "failed_evacuation_count",
    "peak_enroute_share",
    "total_delay_hours",
    "shelter_overflow_count",
)
"""Metric names that :func:`compute_bands` aggregates."""


@dataclass(frozen=True, slots=True)
class BandResult:
    """Per-shift quantile statistics for each metric.

    ``percentiles`` are the requested percentile cuts (e.g. ``(5, 25, 50, 75,
    95)``). ``shifts`` is the unique sorted shift values. ``bands`` is a mapping
    ``metric -> (n_percentiles, n_shifts)`` array of quantile values.
    """

    percentiles: tuple[int, ...]
    shifts: tuple[int, ...]
    bands: dict[str, FloatArray]

    def quantile(self, metric: str, percentile: int) -> FloatArray:
        """Return the per-shift array for a single ``(metric, percentile)``."""
        if metric not in self.bands:
            msg = f"Unknown metric: {metric!r}. Known: {tuple(self.bands.keys())}"
            raise KeyError(msg)
        if percentile not in self.percentiles:
            msg = f"Percentile {percentile} not computed. Available: {self.percentiles}"
            raise KeyError(msg)
        row = self.percentiles.index(percentile)
        return np.asarray(self.bands[metric][row], dtype=np.float64)


def metric_matrix(sweep_result: ShiftSweepResult, metric: str) -> FloatArray:
    """Pivot the sweep rows into a ``(n_replicates, n_shifts)`` grid for ``metric``.

    Rows are ordered by ascending replicate id and columns by ascending shift.
    Any ``(replicate, shift)`` cell with no row stays ``NaN``. Raises
    ``ValueError`` if ``metric`` is not listed in :data:`BAND_METRICS`.
    """
    if metric not in BAND_METRICS:
        msg = f"Unknown metric: {metric!r}. Known: {BAND_METRICS}"
        raise ValueError(msg)

    records = sweep_result.rows
    shift_values = sorted({int(rec.shift) for rec in records})
    replicate_ids = sorted({int(rec.replicate_id) for rec in records})

    # Position of each distinct key along its axis of the output grid.
    row_of = dict(zip(replicate_ids, range(len(replicate_ids))))
    col_of = dict(zip(shift_values, range(len(shift_values))))

    matrix = np.empty((len(replicate_ids), len(shift_values)), dtype=np.float64)
    matrix.fill(np.nan)
    for rec in records:
        # If a cell appears more than once, the last row written wins.
        matrix[row_of[int(rec.replicate_id)], col_of[int(rec.shift)]] = float(
            getattr(rec, metric)
        )
    return matrix


def compute_bands(
    sweep_result: ShiftSweepResult,
    percentiles: tuple[int, ...] = (5, 25, 50, 75, 95),
) -> BandResult:
    """Per-shift quantile statistics for every metric in :data:`BAND_METRICS`.

    For each metric, the ``(n_replicates, n_shifts)`` grid from
    :func:`metric_matrix` is reduced over the replicate axis with
    :func:`numpy.nanpercentile`, so ``NaN`` cells are ignored. Each resulting
    array has shape ``(n_percentiles, n_shifts)``.

    The percentile axis preserves the input order; the shift axis is sorted
    ascending.

    Raises ``ValueError`` if ``percentiles`` is empty, if any value lies
    outside ``[0, 100]`` (including ``NaN``), if any value is not
    integer-valued, or if ``sweep_result`` has no rows.
    """
    if not percentiles:
        msg = "percentiles must not be empty"
        raise ValueError(msg)
    # Written as a positive range test so that NaN, which fails every
    # comparison, is rejected here too.
    if not all(0 <= p <= 100 for p in percentiles):
        msg = f"percentiles must lie in [0, 100], got {percentiles}"
        raise ValueError(msg)
    # The result stores percentiles as ints. A fractional cut such as 2.5
    # would be computed at 2.5 but labelled 2, so a later
    # ``BandResult.quantile(metric, 2)`` would return the wrong band.
    if any(int(p) != p for p in percentiles):
        msg = f"percentiles must be integer-valued, got {percentiles}"
        raise ValueError(msg)
    rows = sweep_result.rows
    if not rows:
        msg = "sweep_result is empty; nothing to aggregate"
        raise ValueError(msg)

    shifts_sorted = tuple(sorted({int(r.shift) for r in rows}))
    p_arr = np.asarray(percentiles, dtype=np.float64)
    bands: dict[str, FloatArray] = {}
    for metric in BAND_METRICS:
        grid = metric_matrix(sweep_result, metric)  # (R, S)
        # axis=0 collapses replicates, leaving one row per requested percentile.
        bands[metric] = np.asarray(np.nanpercentile(grid, p_arr, axis=0), dtype=np.float64)
    return BandResult(
        percentiles=tuple(int(p) for p in percentiles),
        shifts=shifts_sorted,
        bands=bands,
    )