src/iohmm_evac/sweep.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Run all four scenarios under a common seed and bundle the outputs.

Sweep directory layout::

    output/sweep/
    ├── baseline/
    │   ├── observations.parquet
    │   ├── observations.population.parquet
    │   ├── observations.timeline.parquet
    │   ├── observations.config.toml
    │   └── network_metrics.toml
    ├── early-warning/
    │   └── (same)
    ├── targeted-messaging/
    ├── contraflow/
    └── sweep.toml

Each sub-directory is exactly what :func:`iohmm_evac.io.write_results`
produces, with one extra TOML sidecar carrying the post-hoc network metrics.
"""

from __future__ import annotations

import tomllib
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any

import numpy as np
import tomli_w

from iohmm_evac import __version__
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.network import NetworkMetrics, compute_network_metrics
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.scenarios import build_scenario, list_scenarios

# Public API of this module; everything prefixed with an underscore is internal.
__all__ = [
    "DEFAULT_SCENARIOS",
    "SweepConfig",
    "SweepResult",
    "load_sweep",
    "run_sweep",
]


# Scenario names a sweep runs when none are given explicitly. The order here is
# the order they are simulated in and recorded in ``sweep.toml``.
DEFAULT_SCENARIOS: tuple[str, ...] = (
    "baseline",
    "early-warning",
    "targeted-messaging",
    "contraflow",
)


@dataclass(frozen=True, slots=True)
class SweepConfig:
    """Configuration for a multi-scenario sweep.

    Fields are normalised on construction:

    * ``output_dir`` is coerced to :class:`~pathlib.Path`, so a plain string
      path works (matching :func:`load_sweep`, which also accepts any path-like).
    * ``scenarios`` is coerced to a tuple, so a list of names is accepted and
      the frozen config stays hashable. A bare string is treated as a single
      scenario name rather than being split into characters.
    """

    # Root directory; each scenario is written to ``output_dir / <scenario>``
    # and the sweep-level ``sweep.toml`` sits directly inside it.
    output_dir: Path
    # Scenario names to run, in order.
    scenarios: tuple[str, ...] = field(default=DEFAULT_SCENARIOS)
    # Seed applied to every scenario's simulation config.
    seed: int = 0
    n_households: int = 10_000
    n_hours: int = 120

    def __post_init__(self) -> None:
        # The dataclass is frozen, so bypass the generated __setattr__ to
        # store the normalised values.
        object.__setattr__(self, "output_dir", Path(self.output_dir))
        names = self.scenarios
        if isinstance(names, str):
            names = (names,)
        object.__setattr__(self, "scenarios", tuple(names))


@dataclass(frozen=True, slots=True)
class SweepResult:
    """Result of a sweep run, suitable for passing to plotting code."""

    bundles: dict[str, Path]
    """Scenario name → main observations Parquet path."""

    network_metrics: dict[str, NetworkMetrics]
    """Scenario name → computed :class:`NetworkMetrics`."""

    config: SweepConfig
    """The :class:`SweepConfig` that produced this sweep."""


_OBSERVATIONS_NAME = "observations.parquet"
_NETWORK_METRICS_NAME = "network_metrics.toml"
_TOP_LEVEL_NAME = "sweep.toml"


def _scenario_dir(output_dir: Path, scenario: str) -> Path:
    """Return the sub-directory of ``output_dir`` that holds ``scenario``'s outputs."""
    return output_dir.joinpath(scenario)


def _validate_scenarios(names: tuple[str, ...]) -> None:
    """Ensure every entry of ``names`` is a scenario known to :func:`list_scenarios`.

    Raises:
        ValueError: listing the unrecognised names and the sorted known ones.
    """
    registered = set(list_scenarios())
    missing = [name for name in names if name not in registered]
    if not missing:
        return
    msg = f"Unknown scenario(s): {missing}. Known: {sorted(registered)}"
    raise ValueError(msg)


def _network_metrics_to_toml(metrics: NetworkMetrics) -> bytes:
    """Serialize :class:`NetworkMetrics` to a UTF-8 encoded TOML payload.

    The payload has two tables, matching what :func:`_network_metrics_from_dict`
    reads back:

    * ``summary`` -- scalar metrics.
    * ``diagnostics`` -- per-hour arrays stored as plain lists so they can be
      round-tripped without numpy dependencies on the reader side.

    Every value is converted to a builtin ``int``/``float`` before encoding.
    ``tomli_w`` only serializes builtin types and raises ``TypeError`` on numpy
    integer scalars (e.g. ``numpy.int64``). The summary fields may well be
    numpy scalars if they come from numpy reductions (not visible from here).
    """
    payload: dict[str, Any] = {
        "summary": {
            "total_delay_hours": float(metrics.total_delay_hours),
            "peak_enroute_share": float(metrics.peak_enroute_share),
            "peak_enroute_hour": int(metrics.peak_enroute_hour),
            "shelter_overflow_count": int(metrics.shelter_overflow_count),
            "failed_evacuation_count": int(metrics.failed_evacuation_count),
        },
        "diagnostics": {
            "delay_per_hour": [float(x) for x in metrics.delay_per_hour.tolist()],
            "enroute_count_per_hour": [int(x) for x in metrics.enroute_count_per_hour.tolist()],
            "arrivals_away_per_hour": [int(x) for x in metrics.arrivals_away_per_hour.tolist()],
        },
    }
    return tomli_w.dumps(payload).encode("utf-8")


def _network_metrics_from_dict(data: dict[str, Any]) -> NetworkMetrics:
    """Rebuild :class:`NetworkMetrics` from a parsed ``network_metrics.toml`` mapping.

    Expects the ``summary`` / ``diagnostics`` tables written by
    :func:`_network_metrics_to_toml`. Scalars are coerced to ``float``/``int``
    and per-hour lists to ``float64`` / ``int64`` numpy arrays.

    Raises:
        KeyError: if a table or field is missing.
    """
    summary = data["summary"]
    diagnostics = data["diagnostics"]

    def _per_hour(key: str, dtype: Any) -> np.ndarray:
        # Lists were stored as plain TOML arrays; restore them as numpy arrays.
        return np.asarray(diagnostics[key], dtype=dtype)

    return NetworkMetrics(
        total_delay_hours=float(summary["total_delay_hours"]),
        peak_enroute_share=float(summary["peak_enroute_share"]),
        peak_enroute_hour=int(summary["peak_enroute_hour"]),
        shelter_overflow_count=int(summary["shelter_overflow_count"]),
        failed_evacuation_count=int(summary["failed_evacuation_count"]),
        delay_per_hour=_per_hour("delay_per_hour", np.float64),
        enroute_count_per_hour=_per_hour("enroute_count_per_hour", np.int64),
        arrivals_away_per_hour=_per_hour("arrivals_away_per_hour", np.int64),
    )


def _top_level_payload(config: SweepConfig) -> bytes:
    """Encode the sweep-level ``sweep.toml`` contents as UTF-8 TOML bytes.

    Records the package version together with the sweep-wide seed, sizes and
    scenario list, which :func:`load_sweep` reads back.
    """
    document = dict(
        version=__version__,
        seed=int(config.seed),
        n_households=int(config.n_households),
        n_hours=int(config.n_hours),
        scenarios=list(config.scenarios),
    )
    text = tomli_w.dumps(document)
    return text.encode("utf-8")


def _run_one_scenario(
    scenario: str,
    sweep_config: SweepConfig,
) -> tuple[Path, NetworkMetrics]:
    """Simulate one scenario, write its bundle, and compute its network metrics.

    The scenario's built config has its ``n_households``, ``n_hours`` and
    ``seed`` replaced by the sweep-wide values before simulating.

    Returns:
        ``(observations_path, metrics)`` for the scenario.
    """
    overridden = replace(
        build_scenario(scenario),
        n_households=sweep_config.n_households,
        n_hours=sweep_config.n_hours,
        seed=sweep_config.seed,
    )
    # Fresh generator per scenario, seeded from the overridden config.
    result = simulate(overridden, np.random.default_rng(overridden.seed))

    target_dir = _scenario_dir(sweep_config.output_dir, scenario)
    target_dir.mkdir(parents=True, exist_ok=True)
    observations = target_dir / _OBSERVATIONS_NAME
    write_results(result, observations)

    # Metrics are derived from the bundle as re-loaded from disk.
    metrics = compute_network_metrics(load_bundle(observations))
    sidecar = target_dir / _NETWORK_METRICS_NAME
    sidecar.write_bytes(_network_metrics_to_toml(metrics))
    return observations, metrics


def run_sweep(config: SweepConfig) -> SweepResult:
    """Simulate each scenario in ``config.scenarios`` and write the sweep to disk.

    Scenarios run sequentially in the given order, each into its own
    sub-directory of ``config.output_dir``.

    Raises:
        ValueError: if any scenario name is not registered.
    """
    _validate_scenarios(config.scenarios)
    config.output_dir.mkdir(parents=True, exist_ok=True)

    per_scenario = {name: _run_one_scenario(name, config) for name in config.scenarios}
    bundles: dict[str, Path] = {name: path for name, (path, _) in per_scenario.items()}
    metrics: dict[str, NetworkMetrics] = {name: nm for name, (_, nm) in per_scenario.items()}

    # Written last: load_sweep treats this file as the sweep marker and
    # refuses a directory that lacks it.
    (config.output_dir / _TOP_LEVEL_NAME).write_bytes(_top_level_payload(config))
    return SweepResult(bundles=bundles, network_metrics=metrics, config=config)


def load_sweep(output_dir: Path) -> SweepResult:
    """Reload a sweep previously written by :func:`run_sweep`.

    Only ``sweep.toml`` and each scenario's ``network_metrics.toml`` are parsed.
    The observations Parquet files are only checked for existence and returned
    as paths.

    Raises:
        FileNotFoundError: if ``sweep.toml`` or any scenario's observations or
            network-metrics file is missing.
    """
    root = Path(output_dir)
    marker = root / _TOP_LEVEL_NAME
    if not marker.exists():
        msg = f"Sweep marker not found: {marker}"
        raise FileNotFoundError(msg)
    with marker.open("rb") as fh:
        top = tomllib.load(fh)

    names = tuple(str(name) for name in top["scenarios"])
    config = SweepConfig(
        output_dir=root,
        scenarios=names,
        seed=int(top["seed"]),
        n_households=int(top["n_households"]),
        n_hours=int(top["n_hours"]),
    )

    bundles: dict[str, Path] = {}
    metrics: dict[str, NetworkMetrics] = {}
    for name in names:
        base = _scenario_dir(root, name)
        observations = base / _OBSERVATIONS_NAME
        sidecar = base / _NETWORK_METRICS_NAME
        if not observations.exists():
            msg = f"Missing observations parquet for scenario {name!r}: {observations}"
            raise FileNotFoundError(msg)
        if not sidecar.exists():
            msg = f"Missing network metrics for scenario {name!r}: {sidecar}"
            raise FileNotFoundError(msg)
        with sidecar.open("rb") as fh:
            metrics[name] = _network_metrics_from_dict(tomllib.load(fh))
        bundles[name] = observations

    return SweepResult(bundles=bundles, network_metrics=metrics, config=config)