# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Aggregate per-cell sweep rows into shift-wise quantile bands."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from iohmm_evac.bootstrap.shift_sweep import ShiftSweepResult
from iohmm_evac.types import FloatArray
# Public names exported by this module.
__all__ = [
    "BAND_METRICS",
    "BandResult",
    "compute_bands",
    "metric_matrix",
]
# Each name is read off a sweep row with ``getattr`` in ``metric_matrix``,
# which also rejects any metric name that is not listed here.
BAND_METRICS: tuple[str, ...] = (
    "failed_evacuation_count",
    "peak_enroute_share",
    "total_delay_hours",
    "shelter_overflow_count",
)
"""Metric names that :func:`compute_bands` aggregates."""
@dataclass(frozen=True, slots=True)
class BandResult:
"""Per-shift quantile statistics for each metric.
``percentiles`` are the requested percentile cuts (e.g. ``(5, 25, 50, 75,
95)``). ``shifts`` is the unique sorted shift values. ``bands`` is a mapping
``metric -> (n_percentiles, n_shifts)`` array of quantile values.
"""
percentiles: tuple[int, ...]
shifts: tuple[int, ...]
bands: dict[str, FloatArray]
def quantile(self, metric: str, percentile: int) -> FloatArray:
"""Return the per-shift array for a single ``(metric, percentile)``."""
if metric not in self.bands:
msg = f"Unknown metric: {metric!r}. Known: {tuple(self.bands.keys())}"
raise KeyError(msg)
if percentile not in self.percentiles:
msg = f"Percentile {percentile} not computed. Available: {self.percentiles}"
raise KeyError(msg)
row = self.percentiles.index(percentile)
return np.asarray(self.bands[metric][row], dtype=np.float64)
def metric_matrix(sweep_result: ShiftSweepResult, metric: str) -> FloatArray:
    """Pivot the sweep rows into a replicate-by-shift matrix for ``metric``.

    Rows follow the sorted unique replicate ids and columns the sorted unique
    shifts, giving a ``(n_replicates, n_shifts)`` array. Cells for which the
    sweep holds no ``(replicate, shift)`` row stay ``NaN``.

    Raises:
        ValueError: If ``metric`` is not one of :data:`BAND_METRICS`.
    """
    if metric not in BAND_METRICS:
        msg = f"Unknown metric: {metric!r}. Known: {BAND_METRICS}"
        raise ValueError(msg)

    records = sweep_result.rows
    unique_shifts = sorted({int(rec.shift) for rec in records})
    unique_replicates = sorted({int(rec.replicate_id) for rec in records})
    column_of = {shift: col for col, shift in enumerate(unique_shifts)}
    row_of = {replicate: row for row, replicate in enumerate(unique_replicates)}

    # Start from all-NaN so that cells without a record remain marked as missing.
    matrix = np.full(
        (len(unique_replicates), len(unique_shifts)),
        np.nan,
        dtype=np.float64,
    )
    for rec in records:
        row = row_of[int(rec.replicate_id)]
        col = column_of[int(rec.shift)]
        matrix[row, col] = float(getattr(rec, metric))
    return matrix
def compute_bands(
    sweep_result: ShiftSweepResult,
    percentiles: tuple[int, ...] = (5, 25, 50, 75, 95),
) -> BandResult:
    """Compute per-shift quantile bands for every metric in :data:`BAND_METRICS`.

    For each metric the ``(n_replicates, n_shifts)`` matrix from
    :func:`metric_matrix` is reduced over the replicate axis with
    ``np.nanpercentile``, so missing ``(replicate, shift)`` cells are ignored.
    The percentile axis of each band preserves the input order; the shift
    axis is sorted ascending.

    Args:
        sweep_result: Sweep whose ``rows`` are aggregated.
        percentiles: Percentile cuts to compute. Each must be a whole number
            in ``[0, 100]``.

    Returns:
        A :class:`BandResult` holding the integer percentile cuts, the sorted
        unique shifts and one ``(n_percentiles, n_shifts)`` array per metric.

    Raises:
        ValueError: If ``percentiles`` is empty, contains a value outside
            ``[0, 100]`` (including ``NaN``) or a fractional value, or if
            ``sweep_result`` has no rows.
    """
    if not percentiles:
        msg = "percentiles must not be empty"
        raise ValueError(msg)
    # Written as ``not (0 <= p <= 100)`` so that NaN, which fails every
    # comparison, is rejected here as well.
    if any(not (0 <= p <= 100) for p in percentiles):
        msg = f"percentiles must lie in [0, 100], got {percentiles}"
        raise ValueError(msg)
    # The result labels each band with ``int(p)``. A fractional cut such as
    # 2.5 would be computed at 2.5 but stored under the label 2, so reject it
    # instead of silently mislabelling the band.
    if any(int(p) != p for p in percentiles):
        msg = f"percentiles must be whole numbers, got {percentiles}"
        raise ValueError(msg)
    if not sweep_result.rows:
        msg = "sweep_result is empty; nothing to aggregate"
        raise ValueError(msg)

    shifts_sorted = tuple(sorted({int(r.shift) for r in sweep_result.rows}))
    cuts = tuple(int(p) for p in percentiles)
    p_arr = np.asarray(cuts, dtype=np.float64)
    # Each grid is (n_replicates, n_shifts); reducing over axis 0 leaves a
    # (n_percentiles, n_shifts) array.
    bands: dict[str, FloatArray] = {
        metric: np.asarray(
            np.nanpercentile(metric_matrix(sweep_result, metric), p_arr, axis=0),
            dtype=np.float64,
        )
        for metric in BAND_METRICS
    }
    return BandResult(percentiles=cuts, shifts=shifts_sorted, bands=bands)