src/iohmm_evac/bootstrap/shift_sweep.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Shift-sweep driver: simulate failed-evac counts under warning-time shifts.

For each ``(replicate, shift)`` pair, build a :class:`SimulationConfig` with
the voluntary and mandatory warning hours shifted by ``shift`` hours, swap the
simulator's transition logits for the replicate's IO-HMM-fitted parameters, run
the simulation, and compute the four network metrics (failed evacuations, peak
en-route share, total delay hours, shelter overflow). The IO-HMM-fitted
transitions have the same multinomial-logit form as the DGP rows, but the
IO-HMM input vector is a strict subset (no endogenous ``π``, ``c``, ``tir``);
congestion still enters the simulation through the emission model.
"""

from __future__ import annotations

from dataclasses import dataclass, replace
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from numpy.random import Generator

from iohmm_evac.bootstrap.runner import BootstrapFit
from iohmm_evac.dgp.emissions import sample_emissions
from iohmm_evac.dgp.feedback import congestion
from iohmm_evac.dgp.population import synthesize_population
from iohmm_evac.dgp.simulator import SimulationResult
from iohmm_evac.dgp.timeline import build_timeline
from iohmm_evac.inference.fit_params import FEATURE_NAMES, F, FitParameters, K
from iohmm_evac.inference.log_space import LOG_EPS
from iohmm_evac.network.metrics import NetworkMetrics, compute_metrics_from_arrays
from iohmm_evac.params import SimulationConfig, TimelineParams
from iohmm_evac.types import FloatArray, IntArray, State

# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = [
    "ShiftSweepResult",
    "SweepRow",
    "load_sweep_result",
    "run_shift_sweep",
    "shift_timeline",
    "simulate_with_iohmm_transitions",
    "write_sweep_result",
]


# Column position of each IO-HMM input feature, keyed by its name in FEATURE_NAMES.
_FEATURE_INDEX: dict[str, int] = dict(zip(FEATURE_NAMES, range(len(FEATURE_NAMES))))
# Integer code of the shelter state; ``_sample_iohmm_step`` never resamples it.
_SH = int(State.SH)


@dataclass(frozen=True, slots=True)
class SweepRow:
    """One row of the long-format shift-sweep output table.

    Each row holds the network metrics of a single simulated cell, identified
    by the bootstrap replicate whose IO-HMM parameters drove the transitions
    and the warning-time shift applied to the timeline.
    """

    # Bootstrap replicate whose fitted IO-HMM parameters were simulated.
    replicate_id: int
    # Hours added to both the voluntary and mandatory warning hours.
    shift: int
    # The remaining fields are copied from ``NetworkMetrics`` for this cell.
    failed_evacuation_count: int
    peak_enroute_share: float
    total_delay_hours: float
    shelter_overflow_count: int


@dataclass(frozen=True, slots=True)
class ShiftSweepResult:
    """Result of a bootstrap × shift sweep."""

    # One row per (replicate, shift) cell.
    rows: tuple[SweepRow, ...]
    # Shifts evaluated: input order from ``run_shift_sweep``; sorted unique
    # values when reconstructed by ``load_sweep_result``.
    shifts: tuple[int, ...]
    # Number of bootstrap replicates (distinct replicate ids when loaded).
    n_replicates: int


def shift_timeline(timeline: TimelineParams, shift: int) -> TimelineParams:
    """Move both warning issuance hours of ``timeline`` by ``shift`` hours.

    A positive ``shift`` moves the voluntary and mandatory hours later, a
    negative one earlier. Every other field is carried over unchanged, and a
    new object is returned (the input is not mutated).
    """
    offset = int(shift)
    new_voluntary = int(timeline.voluntary_hour) + offset
    new_mandatory = int(timeline.mandatory_hour) + offset
    return replace(timeline, voluntary_hour=new_voluntary, mandatory_hour=new_mandatory)


def _iohmm_input_vector(
    forecast_t: float,
    vol_t: int,
    mand_t: int,
    distance: FloatArray,
    risk: FloatArray,
    vehicle: FloatArray,
    tau: float,
) -> FloatArray:
    """Assemble the per-household IO-HMM design matrix for one hour.

    Returns a ``(len(distance), F)`` float64 array. The ``vol``, ``mand`` and
    ``tau`` columns are broadcast scalars; ``rho`` is the forecast scaled by
    ``exp(-distance / 10)``; ``r`` and ``v`` copy the household ``risk`` and
    ``vehicle`` attributes. Any feature not listed here stays at zero.
    """
    design = np.zeros((distance.shape[0], F), dtype=np.float64)
    # Columns are filled in this order; names must exist in _FEATURE_INDEX.
    column_values = {
        "vol": float(vol_t),
        "mand": float(mand_t),
        "rho": float(forecast_t) * np.exp(-distance / 10.0),
        "r": risk,
        "v": vehicle,
        "tau": float(tau),
    }
    for feature, value in column_values.items():
        design[:, _FEATURE_INDEX[feature]] = value
    return design


def _sample_iohmm_step(
    prev_state: IntArray,
    u_t: FloatArray,
    fit_params: FitParameters,
    rng: Generator,
) -> IntArray:
    """Advance every household one hour using the fitted IO-HMM transitions.

    Households are grouped by current state ``k``. For each non-empty group
    other than ``State.SH``, the row logits are ``alpha[k] + u @ beta[k]ᵀ``.
    Targets where ``alpha[k]`` is non-finite are pinned to ``LOG_EPS``. The
    next state is then drawn by inverse-CDF sampling on the row softmax.
    Households in ``State.SH`` keep their state and consume no random draws.

    Random numbers are drawn one group at a time in increasing ``k``, so the
    RNG stream is consumed in a fixed order.

    Returns:
        A new array of next states; ``prev_state`` is left unmodified.
    """
    alpha = fit_params.transitions.alpha
    beta = fit_params.transitions.beta
    next_state = prev_state.copy()
    for origin in range(K):
        if origin == _SH:
            # Shelter is absorbing here: the copied state is kept as-is.
            continue
        members = np.flatnonzero(prev_state == origin)
        if members.size == 0:
            continue
        inputs = u_t[members]
        row_alpha = alpha[origin]
        logits = row_alpha[None, :] + np.einsum("nf,jf->nj", inputs, beta[origin])
        blocked = ~np.isfinite(row_alpha)
        if blocked.any():
            # Forbidden targets get LOG_EPS (presumably a very negative finite
            # floor) instead of -inf arithmetic.
            logits = np.where(blocked[None, :], LOG_EPS, logits)
        # Numerically stable row softmax.
        stabilized = logits - np.max(logits, axis=1, keepdims=True)
        weights = np.exp(stabilized)
        probs = weights / weights.sum(axis=1, keepdims=True)
        cdf = np.cumsum(probs, axis=1)
        # Close any floating-point gap so every draw in [0, 1) finds a bin.
        cdf[:, -1] = 1.0
        draws = np.asarray(rng.random(size=members.shape[0]), dtype=np.float64).reshape(-1, 1)
        next_state[members] = np.argmax(draws < cdf, axis=1).astype(np.int64)
    return next_state


def _sample_initial_state(initial_probs: FloatArray, n: int, rng: Generator) -> IntArray:
    """Draw ``n`` i.i.d. initial states from ``initial_probs`` by inverse-CDF sampling.

    Each uniform draw ``u`` maps to the first index whose cumulative
    probability exceeds ``u``, so zero-probability states are never chosen.
    The result is clipped to ``[0, K - 1]``.
    """
    cdf = np.cumsum(initial_probs)
    # Absorb floating-point shortfall so every draw in [0, 1) lands in range.
    cdf[-1] = 1.0
    draws = np.asarray(rng.random(size=n), dtype=np.float64)
    bins = np.asarray(np.searchsorted(cdf, draws, side="right"), dtype=np.int64)
    return np.clip(bins, 0, K - 1)


def simulate_with_iohmm_transitions(
    config: SimulationConfig,
    fit_params: FitParameters,
    rng: Generator,
) -> SimulationResult:
    """Run the simulator with IO-HMM transitions in place of the DGP transitions.

    Population, timeline, emissions, and feedback come from ``config``;
    transitions are sampled from the IO-HMM logit form
    ``α[k, j] + β[k, j, :] · u_{i, t}``.

    Initial states are drawn from ``fit_params.initial.probs()``. Congestion
    is computed from the previous hour's states and passed only to the
    emission sampler; it is not part of the transition inputs.

    All randomness comes from ``rng`` and is consumed in a fixed order:
    population, timeline, initial states, and initial emissions, then for each
    hour the transitions followed by the emissions. Reordering these calls
    changes the output for a given seed.

    Args:
        config: Scenario settings. ``n_households`` and ``n_hours`` fix the
            array sizes.
        fit_params: Fitted IO-HMM parameters (one bootstrap replicate when
            called from ``run_shift_sweep``).
        rng: Random generator, advanced in place.

    Returns:
        A :class:`SimulationResult`. Its per-hour arrays have shape
        ``(n_households, n_hours + 1)``, with column 0 as the initial hour.
        Its ``evac_path`` is 1 after a PR→ER transition, 2 after PR→SH, and
        0 otherwise; the last such transition wins.
    """
    n = config.n_households
    t_total = config.n_hours
    pop = synthesize_population(n, rng, config.population)
    timeline = build_timeline(t_total, rng, config.timeline)

    # Per-hour outputs: column 0 is the initial hour, columns 1..t_total the steps.
    states = np.zeros((n, t_total + 1), dtype=np.int64)
    departures = np.zeros((n, t_total + 1), dtype=bool)
    displacements = np.zeros((n, t_total + 1), dtype=np.float64)
    communications = np.zeros((n, t_total + 1), dtype=np.int64)
    # Route code per household: 1 after PR->ER, 2 after PR->SH, otherwise 0.
    evac_path = np.zeros(n, dtype=np.int64)
    # Consecutive hours (including the current one) spent in State.ER; 0 otherwise.
    tir = np.zeros(n, dtype=np.float64)

    states[:, 0] = _sample_initial_state(fit_params.initial.probs(), n, rng)
    # Hour-0 emissions use a fixed congestion level of 0.0.
    departures[:, 0], displacements[:, 0], communications[:, 0] = sample_emissions(
        states[:, 0], evac_path, tir, pop.destination, 0.0, config.emissions, rng
    )
    for t in range(1, t_total + 1):
        prev = states[:, t - 1]
        # Congestion from last hour's states; used below by the emissions only.
        c_t = congestion(prev, config.feedback.n_cap)
        # tau is elapsed time as a fraction of the horizon (t / t_total).
        u_t = _iohmm_input_vector(
            float(timeline.forecast[t]),
            int(timeline.voluntary[t]),
            int(timeline.mandatory[t]),
            pop.distance,
            pop.risk,
            pop.vehicle,
            t / t_total,
        )
        new_state = _sample_iohmm_step(prev, u_t, fit_params, rng)
        # Record how a household left PR; a later PR exit overwrites an earlier one.
        evac_path[(prev == State.PR) & (new_state == State.ER)] = 1
        evac_path[(prev == State.PR) & (new_state == State.SH)] = 2
        tir = np.where(new_state == State.ER, tir + 1.0, 0.0)
        states[:, t] = new_state
        # Emissions at hour t see the updated state, route code, and time-in-route.
        departures[:, t], displacements[:, t], communications[:, t] = sample_emissions(
            new_state, evac_path, tir, pop.destination, c_t, config.emissions, rng
        )

    return SimulationResult(
        states=states,
        departures=departures,
        displacements=displacements,
        communications=communications,
        population=pop,
        timeline=timeline,
        evac_path=evac_path,
        config=config,
    )


def _shift_seed(base_seed: int, replicate_id: int, shift_index: int) -> int:
    """Per ``(replicate, shift)`` deterministic seed: ``b*10000 + r*100 + s``.

    The encoding reserves two decimal digits for the shift index. Seeds are
    therefore unique within one sweep (fixed ``base_seed``) only while
    ``0 <= shift_index < 100``. Outside that range a cell would alias another
    one, e.g. ``(r=0, s=100)`` and ``(r=1, s=0)``, and silently reuse its random
    stream. Such indices are rejected instead.

    Note: across sweeps, seeds for ``base_seed`` and ``base_seed + 1`` overlap
    once ``replicate_id >= 100``. Keep base seeds well separated when
    combining sweeps with many replicates.

    Raises:
        ValueError: If ``shift_index`` is outside ``[0, 100)``.
    """
    s = int(shift_index)
    if not 0 <= s < 100:
        msg = f"shift_index must be in [0, 100) to keep per-cell seeds unique; got {s}"
        raise ValueError(msg)
    return int(base_seed) * 10_000 + int(replicate_id) * 100 + s


def _row_from_metrics(replicate_id: int, shift: int, metrics: NetworkMetrics) -> SweepRow:
    """Flatten one cell's :class:`NetworkMetrics` into a :class:`SweepRow`.

    Every value is coerced to a builtin ``int``/``float``, so any NumPy scalar
    types do not leak into the output table.
    """
    rid = int(replicate_id)
    shift_hours = int(shift)
    failed = int(metrics.failed_evacuation_count)
    peak_share = float(metrics.peak_enroute_share)
    delay = float(metrics.total_delay_hours)
    overflow = int(metrics.shelter_overflow_count)
    return SweepRow(
        replicate_id=rid,
        shift=shift_hours,
        failed_evacuation_count=failed,
        peak_enroute_share=peak_share,
        total_delay_hours=delay,
        shelter_overflow_count=overflow,
    )


def run_shift_sweep(
    bootstrap_fits: list[BootstrapFit],
    shifts: tuple[int, ...],
    scenario_base: SimulationConfig,
    n_households: int,
    n_hours: int,
    base_seed: int,
) -> ShiftSweepResult:
    """Run a ``(replicate × shift)`` sweep and return the long-format result.

    Each cell is set up as follows:

    * ``scenario_base`` is copied with the requested population size and horizon.
    * The warning hours are moved by the shift.
    * The copy gets a cell-specific seed from :func:`_shift_seed`.

    The replicate's fitted IO-HMM parameters then drive the transitions. The
    trajectories are scored with ``scenario_base``'s feedback and emission
    settings.

    Args:
        bootstrap_fits: Bootstrap replicates to simulate; must be non-empty.
        shifts: Warning-time shifts in hours; must be non-empty.
        scenario_base: Template configuration shared by every cell.
        n_households: Population size for every cell.
        n_hours: Simulation horizon for every cell.
        base_seed: Root of the per-cell seed scheme.

    Returns:
        A :class:`ShiftSweepResult`. Rows are ordered by replicate first, then
        by position in ``shifts``.

    Raises:
        ValueError: If ``bootstrap_fits`` or ``shifts`` is empty.
    """
    if not bootstrap_fits:
        msg = "bootstrap_fits must not be empty"
        raise ValueError(msg)
    if not shifts:
        msg = "shifts must not be empty"
        raise ValueError(msg)

    # Metric settings are identical for every cell, so resolve them once.
    n_cap = int(scenario_base.feedback.n_cap)
    shelter_capacity = int(scenario_base.feedback.shelter_capacity)
    v_free = float(scenario_base.emissions.v_free)
    congestion_penalty = float(scenario_base.emissions.congestion_penalty)

    rows: list[SweepRow] = []
    for fit in bootstrap_fits:
        for shift_index, raw_shift in enumerate(shifts):
            shift_hours = int(raw_shift)
            cell_config = replace(
                scenario_base,
                n_households=int(n_households),
                n_hours=int(n_hours),
                timeline=shift_timeline(scenario_base.timeline, shift_hours),
                seed=_shift_seed(base_seed, fit.replicate_id, shift_index),
            )
            cell_rng = np.random.default_rng(cell_config.seed)
            sim = simulate_with_iohmm_transitions(cell_config, fit.params, cell_rng)
            metrics = compute_metrics_from_arrays(
                states=sim.states,
                displacements=sim.displacements,
                evac_path=sim.evac_path,
                n_cap=n_cap,
                shelter_capacity=shelter_capacity,
                v_free=v_free,
                congestion_penalty=congestion_penalty,
            )
            rows.append(_row_from_metrics(fit.replicate_id, shift_hours, metrics))

    return ShiftSweepResult(
        rows=tuple(rows),
        shifts=tuple(map(int, shifts)),
        n_replicates=len(bootstrap_fits),
    )


# Metric columns stored as int64 and float64 respectively in the parquet output.
_INT_COLS: tuple[str, ...] = ("failed_evacuation_count", "shelter_overflow_count")
_FLOAT_COLS: tuple[str, ...] = ("peak_enroute_share", "total_delay_hours")


def write_sweep_result(result: ShiftSweepResult, output_path: Path) -> Path:
    """Write the sweep result as a long-format parquet.

    The parquet columns, in order, are:

    * ``replicate_id`` and ``shift`` as int32.
    * The :data:`_INT_COLS` metrics as int64.
    * The :data:`_FLOAT_COLS` metrics as float64.

    Missing parent directories are created. Returns the destination as a
    :class:`~pathlib.Path`.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    column_specs: list[tuple[str, pa.DataType]] = [
        ("replicate_id", pa.int32()),
        ("shift", pa.int32()),
    ]
    column_specs += [(name, pa.int64()) for name in _INT_COLS]
    column_specs += [(name, pa.float64()) for name in _FLOAT_COLS]
    columns: dict[str, pa.Array] = {
        name: pa.array([getattr(row, name) for row in result.rows], type=dtype)
        for name, dtype in column_specs
    }
    pq.write_table(pa.table(columns), destination)  # type: ignore[no-untyped-call]
    return destination


def load_sweep_result(input_path: Path) -> ShiftSweepResult:
    """Read a sweep result back from a parquet written by :func:`write_sweep_result`.

    ``shifts`` is rebuilt as the sorted distinct values of the ``shift``
    column. ``n_replicates`` is the number of distinct ``replicate_id``
    values. Rows keep the order in which they appear in the file.
    """
    table = pq.read_table(Path(input_path))  # type: ignore[no-untyped-call]
    data = {name: table.column(name).to_pylist() for name in table.schema.names}
    replicate_col = data["replicate_id"]

    def _row_at(i: int) -> SweepRow:
        # Coerce each stored value back to the builtin type of its field.
        return SweepRow(
            replicate_id=int(replicate_col[i]),
            shift=int(data["shift"][i]),
            failed_evacuation_count=int(data["failed_evacuation_count"][i]),
            peak_enroute_share=float(data["peak_enroute_share"][i]),
            total_delay_hours=float(data["total_delay_hours"][i]),
            shelter_overflow_count=int(data["shelter_overflow_count"][i]),
        )

    rows = tuple(_row_at(i) for i in range(len(replicate_col)))
    distinct_shifts = {int(s) for s in data["shift"]}
    distinct_replicates = {int(r) for r in replicate_col}
    return ShiftSweepResult(
        rows=rows,
        shifts=tuple(sorted(distinct_shifts)),
        n_replicates=len(distinct_replicates),
    )