# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import numpy as np
import pytest
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.params import SimulationConfig
from iohmm_evac.scenarios import build_scenario, list_scenarios
from iohmm_evac.types import State
def test_scenarios_listed() -> None:
    """``list_scenarios()`` reports exactly the four expected scenario names.

    The names are compared as a set, so ordering and duplicates are ignored.
    """
    available = set(list_scenarios())
    assert available == {
        "baseline",
        "early-warning",
        "targeted-messaging",
        "contraflow",
    }
def test_each_scenario_builds_a_valid_config() -> None:
    """Every listed scenario builds a ``SimulationConfig`` with positive sizes."""
    # ``map`` is lazy, so each scenario is built and checked one at a time.
    for config in map(build_scenario, list_scenarios()):
        assert isinstance(config, SimulationConfig)
        assert config.n_households > 0
        assert config.n_hours > 0
def test_unknown_scenario_raises() -> None:
    """An unrecognised name makes ``build_scenario`` raise ``ValueError``.

    The error message must match the pattern ``"Unknown scenario"``.
    """
    bogus_name = "not-a-scenario"
    with pytest.raises(ValueError, match="Unknown scenario"):
        build_scenario(bogus_name)
def test_early_warning_pulls_departures_earlier() -> None:
    """Early-warning households should first reach ``State.ER`` sooner on average.

    For each of three seeds, the baseline and early-warning scenarios are
    re-instantiated with a reduced cohort and horizon and then simulated.
    Each run is summarised by the mean first index (along axis 1 of
    ``states``) at which a household is in ``State.ER``; households never
    observed in that state are excluded. The per-seed means are averaged and
    the early-warning average must be strictly smaller than the baseline one.
    """
    base = build_scenario("baseline")
    early = build_scenario("early-warning")
    # Use a smallish cohort and average a few seeds to keep the test fast.
    n = 1_000
    t_total = 96  # before landfall hour

    def resized(template: SimulationConfig, seed: int) -> SimulationConfig:
        """Return a test-sized config that reuses ``template``'s components."""
        return SimulationConfig(
            n_households=n,
            n_hours=t_total,
            seed=seed,
            population=template.population,
            timeline=template.timeline,
            transitions=template.transitions,
            emissions=template.emissions,
            feedback=template.feedback,
        )

    def first_departure_time(states: np.ndarray) -> np.ndarray:
        """Return the first axis-1 index in ``State.ER`` for rows that reach it.

        NOTE(review): assumes ``states`` is laid out as (households, hours);
        confirm against ``simulate``.
        """
        on_route = states == int(State.ER)
        ever = on_route.any(axis=1)
        # argmax over a boolean row gives the index of its first True. Rows
        # with no True would report 0, so they are dropped via ``ever``.
        return np.asarray(on_route.argmax(axis=1)[ever])

    departures_base: list[float] = []
    departures_early: list[float] = []
    for seed in (0, 1, 2):
        b = resized(base, seed)
        e = resized(early, seed)
        # Both runs get a fresh generator seeded identically.
        rb = simulate(b, np.random.default_rng(seed))
        re_ = simulate(e, np.random.default_rng(seed))
        t_base = first_departure_time(rb.states)
        t_early = first_departure_time(re_.states)
        # The mean of an empty array is NaN, which would make the final
        # comparison fail without saying why; fail here with a clear message.
        assert t_base.size > 0, f"seed {seed}: no baseline household reached State.ER"
        assert t_early.size > 0, f"seed {seed}: no early-warning household reached State.ER"
        departures_base.append(float(t_base.mean()))
        departures_early.append(float(t_early.mean()))

    assert np.mean(departures_early) < np.mean(departures_base)
def test_contraflow_raises_capacity() -> None:
    """The contraflow scenario's feedback settings carry ``n_cap == 2500``."""
    feedback = build_scenario("contraflow").feedback
    assert feedback.n_cap == 2500
def test_targeted_messaging_zone_multiplier() -> None:
    """The targeted-messaging scenario sets ``targeted_zone_multiplier`` to 1.5."""
    population = build_scenario("targeted-messaging").population
    assert population.targeted_zone_multiplier == 1.5