# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import filecmp
import tomllib
from pathlib import Path
import pytest
from iohmm_evac.scenarios import list_scenarios
from iohmm_evac.sweep import (
DEFAULT_SCENARIOS,
SweepConfig,
load_sweep,
run_sweep,
)
@pytest.fixture(scope="module")
def small_sweep(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Run the default scenario sweep once and return its output directory.

    N=2000 keeps the sweep under 2s while making the peak-hour estimate
    stable enough for the early-warning vs. baseline ordering check.
    Module scope is safe because every test in this file consumes the
    same on-disk sweep and none of them mutate it.
    """
    root = tmp_path_factory.mktemp("small_sweep")
    cfg = SweepConfig(
        output_dir=root / "sweep",
        scenarios=DEFAULT_SCENARIOS,
        seed=0,
        n_households=2000,
        n_hours=120,
    )
    run_sweep(cfg)
    return cfg.output_dir
def test_run_sweep_writes_expected_directory_structure(small_sweep: Path) -> None:
    """Each scenario directory contains the full, non-empty artifact set."""
    assert (small_sweep / "sweep.toml").exists()
    expected_artifacts = (
        "observations.parquet",
        "observations.population.parquet",
        "observations.timeline.parquet",
        "observations.config.toml",
        "network_metrics.toml",
    )
    for scenario in DEFAULT_SCENARIOS:
        scenario_dir = small_sweep / scenario
        assert scenario_dir.is_dir()
        for filename in expected_artifacts:
            artifact = scenario_dir / filename
            assert artifact.exists(), f"missing {artifact}"
            # Zero-byte artifacts indicate a write that was opened but
            # never flushed — treat them the same as missing.
            assert artifact.stat().st_size > 0
def test_load_sweep_round_trips_run_sweep(small_sweep: Path) -> None:
    """load_sweep reads back the config, bundles, and metrics run_sweep wrote."""
    loaded = load_sweep(small_sweep)
    cfg = loaded.config
    assert tuple(cfg.scenarios) == DEFAULT_SCENARIOS
    assert cfg.seed == 0
    expected_names = set(DEFAULT_SCENARIOS)
    assert set(loaded.bundles.keys()) == expected_names
    assert set(loaded.network_metrics.keys()) == expected_names
    # Hourly series carry one entry per hour plus the t=0 sample.
    hourly_shape = (cfg.n_hours + 1,)
    for scenario in DEFAULT_SCENARIOS:
        assert loaded.bundles[scenario].exists()
        metrics = loaded.network_metrics[scenario]
        assert metrics.delay_per_hour.shape == hourly_shape
        assert metrics.enroute_count_per_hour.shape == hourly_shape
        assert metrics.arrivals_away_per_hour.shape == hourly_shape
def test_top_level_marker_contents(small_sweep: Path) -> None:
    """The sweep.toml marker records the exact configuration of the run."""
    marker = small_sweep / "sweep.toml"
    with marker.open("rb") as fh:
        data = tomllib.load(fh)
    assert data["seed"] == 0
    assert data["n_hours"] == 120
    assert data["n_households"] == 2000
    assert tuple(data["scenarios"]) == DEFAULT_SCENARIOS
    assert "version" in data
def test_scenarios_produce_distinct_metrics(small_sweep: Path) -> None:
    """Non-baseline scenarios must not be numerically identical to baseline.

    At least one numerical metric must differ between baseline and another
    scenario — otherwise the scenarios are accidentally identical.
    """
    metrics_by_scenario = load_sweep(small_sweep).network_metrics
    baseline = metrics_by_scenario["baseline"]
    diffs: dict[str, float] = {
        name: (
            abs(metrics_by_scenario[name].total_delay_hours - baseline.total_delay_hours)
            + abs(
                metrics_by_scenario[name].failed_evacuation_count
                - baseline.failed_evacuation_count
            )
        )
        for name in DEFAULT_SCENARIOS
        if name != "baseline"
    }
    assert any(d > 0 for d in diffs.values()), f"all scenarios produced identical metrics: {diffs}"
def test_early_warning_shifts_peak_enroute_earlier(small_sweep: Path) -> None:
    """Earlier evacuation orders should pull the road-network peak earlier.

    The peak must move earlier by a margin of more than 5 hours so the
    check stays above the stochastic variation at the small N used by
    the test fixture — a shift of only an hour or two could be noise.
    """
    loaded = load_sweep(small_sweep)
    baseline_peak = loaded.network_metrics["baseline"].peak_enroute_hour
    early_peak = loaded.network_metrics["early-warning"].peak_enroute_hour
    assert early_peak < baseline_peak - 5, (
        f"early-warning peak ({early_peak}) should be "
        f">5h earlier than baseline ({baseline_peak})"
    )
def test_same_seed_reproduces_baseline_bit_for_bit(tmp_path: Path) -> None:
    """Two runs with identical seed/config yield byte-identical parquet output."""

    def make_config(subdir: str) -> SweepConfig:
        # Identical in every field except the output directory.
        return SweepConfig(
            output_dir=tmp_path / subdir,
            scenarios=("baseline",),
            seed=7,
            n_households=200,
            n_hours=24,
        )

    first = make_config("a")
    second = make_config("b")
    run_sweep(first)
    run_sweep(second)
    path_a = first.output_dir / "baseline" / "observations.parquet"
    path_b = second.output_dir / "baseline" / "observations.parquet"
    # shallow=False forces a full content comparison, not just stat().
    assert filecmp.cmp(path_a, path_b, shallow=False)
def test_run_sweep_subset_of_scenarios(tmp_path: Path) -> None:
    """Running a subset of scenarios writes only those scenario directories.

    Guards against run_sweep ignoring ``scenarios`` and always producing
    the full default set.
    """
    config = SweepConfig(
        output_dir=tmp_path / "subset",
        scenarios=("baseline", "contraflow"),
        seed=0,
        n_households=200,
        n_hours=24,
    )
    result = run_sweep(config)
    assert set(result.bundles.keys()) == {"baseline", "contraflow"}
    assert set(result.network_metrics.keys()) == {"baseline", "contraflow"}
    # Idiomatic negation instead of `... .exists() is False`, which compared
    # identity against the bool singleton rather than asserting the predicate.
    assert not (config.output_dir / "early-warning").exists()
def test_unknown_scenario_raises(tmp_path: Path) -> None:
    """A scenario name absent from the registry fails fast with ValueError."""
    bad_config = SweepConfig(
        output_dir=tmp_path / "bad",
        scenarios=("baseline", "does-not-exist"),
        seed=0,
        n_households=10,
        n_hours=2,
    )
    with pytest.raises(ValueError, match="Unknown scenario"):
        run_sweep(bad_config)
def test_load_sweep_missing_marker(tmp_path: Path) -> None:
    """Loading a directory without a sweep.toml marker raises FileNotFoundError."""
    missing_dir = tmp_path / "no-such"
    with pytest.raises(FileNotFoundError, match=r"sweep\.toml"):
        load_sweep(missing_dir)
def test_default_scenarios_match_registry() -> None:
    """DEFAULT_SCENARIOS must stay in sync with the scenario registry."""
    assert set(list_scenarios()) == set(DEFAULT_SCENARIOS)