# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Run all four scenarios under a common seed and bundle the outputs.
Sweep directory layout::
output/sweep/
├── baseline/
│ ├── observations.parquet
│ ├── observations.population.parquet
│ ├── observations.timeline.parquet
│ ├── observations.config.toml
│ └── network_metrics.toml
├── early-warning/
│ └── (same)
├── targeted-messaging/
├── contraflow/
└── sweep.toml
Each sub-directory is exactly what :func:`iohmm_evac.io.write_results`
produces, with one extra TOML sidecar carrying the post-hoc network metrics.
"""
from __future__ import annotations
import tomllib
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any
import numpy as np
import tomli_w
from iohmm_evac import __version__
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.network import NetworkMetrics, compute_network_metrics
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.scenarios import build_scenario, list_scenarios
# Explicit public API of this module.
__all__ = [
    "DEFAULT_SCENARIOS",
    "SweepConfig",
    "SweepResult",
    "load_sweep",
    "run_sweep",
]

# Scenario names run by a default sweep, in execution order.  Each name must
# be resolvable by :func:`iohmm_evac.scenarios.build_scenario` (run_sweep
# validates against :func:`iohmm_evac.scenarios.list_scenarios` before running).
DEFAULT_SCENARIOS: tuple[str, ...] = (
    "baseline",
    "early-warning",
    "targeted-messaging",
    "contraflow",
)
@dataclass(frozen=True, slots=True)
class SweepConfig:
    """Immutable configuration for a multi-scenario sweep.

    The same seed, household count, and horizon are applied to every
    scenario so results are comparable across sub-directories.
    """

    # Root directory the sweep writes into (one sub-directory per scenario).
    output_dir: Path
    # Scenario names to run; a tuple default is immutable, so no factory needed.
    scenarios: tuple[str, ...] = DEFAULT_SCENARIOS
    # Common RNG seed shared by all scenarios.
    seed: int = 0
    # Simulated population size.
    n_households: int = 10_000
    # Simulation horizon in hours.
    n_hours: int = 120
@dataclass(frozen=True, slots=True)
class SweepResult:
"""Result of a sweep run, suitable for passing to plotting code."""
bundles: dict[str, Path]
"""Scenario name → main observations Parquet path."""
network_metrics: dict[str, NetworkMetrics]
"""Scenario name → computed :class:`NetworkMetrics`."""
config: SweepConfig
"""The :class:`SweepConfig` that produced this sweep."""
# File-name conventions shared by the writer (run_sweep) and reader (load_sweep).
_OBSERVATIONS_NAME = "observations.parquet"  # main per-scenario output
_NETWORK_METRICS_NAME = "network_metrics.toml"  # per-scenario metrics sidecar
_TOP_LEVEL_NAME = "sweep.toml"  # sweep-level marker at the output root
def _scenario_dir(output_dir: Path, scenario: str) -> Path:
return output_dir / scenario
def _validate_scenarios(names: tuple[str, ...]) -> None:
    """Raise ``ValueError`` if any name is not a registered scenario."""
    registered = set(list_scenarios())
    bad = [name for name in names if name not in registered]
    if not bad:
        return
    msg = f"Unknown scenario(s): {bad}. Known: {sorted(registered)}"
    raise ValueError(msg)
def _network_metrics_to_toml(metrics: NetworkMetrics) -> bytes:
    """Encode *metrics* as a UTF-8 TOML document.

    Per-hour diagnostic arrays are flattened to plain Python lists so the
    sidecar can be round-tripped without numpy on the reader side.
    """
    summary = {
        "total_delay_hours": metrics.total_delay_hours,
        "peak_enroute_share": metrics.peak_enroute_share,
        "peak_enroute_hour": metrics.peak_enroute_hour,
        "shelter_overflow_count": metrics.shelter_overflow_count,
        "failed_evacuation_count": metrics.failed_evacuation_count,
    }
    # Coerce each element to a native scalar so tomli_w never sees numpy types.
    diagnostics = {
        "delay_per_hour": [float(v) for v in metrics.delay_per_hour.tolist()],
        "enroute_count_per_hour": [int(v) for v in metrics.enroute_count_per_hour.tolist()],
        "arrivals_away_per_hour": [int(v) for v in metrics.arrivals_away_per_hour.tolist()],
    }
    document: dict[str, Any] = {"summary": summary, "diagnostics": diagnostics}
    return tomli_w.dumps(document).encode("utf-8")
def _network_metrics_from_dict(data: dict[str, Any]) -> NetworkMetrics:
    """Rebuild a :class:`NetworkMetrics` from its parsed TOML representation.

    Inverse of :func:`_network_metrics_to_toml`: scalar fields are re-coerced
    to their native types and diagnostic lists back to numpy arrays.
    """
    summary_tbl = data["summary"]
    diag_tbl = data["diagnostics"]
    return NetworkMetrics(
        total_delay_hours=float(summary_tbl["total_delay_hours"]),
        peak_enroute_share=float(summary_tbl["peak_enroute_share"]),
        peak_enroute_hour=int(summary_tbl["peak_enroute_hour"]),
        shelter_overflow_count=int(summary_tbl["shelter_overflow_count"]),
        failed_evacuation_count=int(summary_tbl["failed_evacuation_count"]),
        delay_per_hour=np.asarray(diag_tbl["delay_per_hour"], dtype=np.float64),
        enroute_count_per_hour=np.asarray(diag_tbl["enroute_count_per_hour"], dtype=np.int64),
        arrivals_away_per_hour=np.asarray(diag_tbl["arrivals_away_per_hour"], dtype=np.int64),
    )
def _top_level_payload(config: SweepConfig) -> bytes:
    """Render the top-level ``sweep.toml`` marker as UTF-8 TOML bytes.

    Captures everything :func:`load_sweep` needs to reconstruct the
    :class:`SweepConfig`, plus the package version for provenance.
    """
    manifest = {
        "version": __version__,
        "seed": int(config.seed),
        "n_households": int(config.n_households),
        "n_hours": int(config.n_hours),
        "scenarios": list(config.scenarios),
    }
    return tomli_w.dumps(manifest).encode("utf-8")
def _run_one_scenario(
    scenario: str,
    sweep_config: SweepConfig,
) -> tuple[Path, NetworkMetrics]:
    """Simulate one scenario, write its outputs, and return (obs path, metrics).

    The scenario's registered config is overlaid with the sweep-wide knobs
    (seed, population, horizon) so every scenario in the sweep is comparable.
    """
    base_config = build_scenario(scenario)
    sim_config = replace(
        base_config,
        n_households=sweep_config.n_households,
        n_hours=sweep_config.n_hours,
        seed=sweep_config.seed,
    )
    rng = np.random.default_rng(sim_config.seed)
    result = simulate(sim_config, rng)

    target_dir = _scenario_dir(sweep_config.output_dir, scenario)
    target_dir.mkdir(parents=True, exist_ok=True)
    obs_path = target_dir / _OBSERVATIONS_NAME
    write_results(result, obs_path)

    # Metrics are computed from the on-disk bundle, so they match exactly
    # what a later load of the same files would see.
    metrics = compute_network_metrics(load_bundle(obs_path))
    (target_dir / _NETWORK_METRICS_NAME).write_bytes(_network_metrics_to_toml(metrics))
    return obs_path, metrics
def run_sweep(config: SweepConfig) -> SweepResult:
    """Run every scenario in ``config.scenarios`` and write outputs to disk."""
    _validate_scenarios(config.scenarios)
    config.output_dir.mkdir(parents=True, exist_ok=True)

    bundles: dict[str, Path] = {}
    metrics: dict[str, NetworkMetrics] = {}
    for name in config.scenarios:
        bundles[name], metrics[name] = _run_one_scenario(name, config)

    # The top-level marker is written only after every scenario finished,
    # so its presence signals a complete sweep on disk.
    (config.output_dir / _TOP_LEVEL_NAME).write_bytes(_top_level_payload(config))
    return SweepResult(bundles=bundles, network_metrics=metrics, config=config)
def load_sweep(output_dir: Path) -> SweepResult:
    """Load a sweep produced by :func:`run_sweep` from disk.

    Raises ``FileNotFoundError`` if the top-level marker or any per-scenario
    file listed in it is missing.
    """
    root = Path(output_dir)
    marker = root / _TOP_LEVEL_NAME
    if not marker.exists():
        msg = f"Sweep marker not found: {marker}"
        raise FileNotFoundError(msg)
    with marker.open("rb") as fh:
        manifest = tomllib.load(fh)

    scenario_names = tuple(str(s) for s in manifest["scenarios"])
    config = SweepConfig(
        output_dir=root,
        scenarios=scenario_names,
        seed=int(manifest["seed"]),
        n_households=int(manifest["n_households"]),
        n_hours=int(manifest["n_hours"]),
    )

    bundles: dict[str, Path] = {}
    metrics: dict[str, NetworkMetrics] = {}
    for name in scenario_names:
        scenario_dir = _scenario_dir(root, name)
        observations = scenario_dir / _OBSERVATIONS_NAME
        if not observations.exists():
            msg = f"Missing observations parquet for scenario {name!r}: {observations}"
            raise FileNotFoundError(msg)
        sidecar = scenario_dir / _NETWORK_METRICS_NAME
        if not sidecar.exists():
            msg = f"Missing network metrics for scenario {name!r}: {sidecar}"
            raise FileNotFoundError(msg)
        with sidecar.open("rb") as fh:
            metrics[name] = _network_metrics_from_dict(tomllib.load(fh))
        bundles[name] = observations

    return SweepResult(bundles=bundles, network_metrics=metrics, config=config)