# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import numpy as np
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.params import SimulationConfig
from iohmm_evac.types import State
def test_simulator_end_to_end_shapes() -> None:
    """Check the output array shapes of a short end-to-end simulation.

    Per-step outputs carry one column per simulated hour plus the initial
    (t = 0) column, i.e. ``n_hours + 1`` columns; ``evac_path`` holds a
    single entry per household.
    """
    # Keep the sizes in named locals so the expected shapes cannot drift
    # out of sync with the config they are derived from.
    n_households, n_hours = 200, 24
    config = SimulationConfig(n_households=n_households, n_hours=n_hours, seed=0)
    rng = np.random.default_rng(0)
    result = simulate(config, rng)

    per_step_shape = (n_households, n_hours + 1)
    per_step_outputs = {
        "states": result.states,
        "departures": result.departures,
        "displacements": result.displacements,
        "communications": result.communications,
    }
    for name, arr in per_step_outputs.items():
        assert arr.shape == per_step_shape, f"{name}: {arr.shape} != {per_step_shape}"
    assert result.evac_path.shape == (n_households,)
def test_simulator_initial_state_is_ua() -> None:
    """Every household begins the simulation (column t = 0) in state UA."""
    config = SimulationConfig(n_households=100, n_hours=12, seed=0)
    first_column = simulate(config, np.random.default_rng(0)).states[:, 0]
    assert np.all(first_column == int(State.UA))
def test_simulator_states_visit_full_space() -> None:
    """Over a large, long run, every ``State`` member is visited at least once.

    Also guards against the simulator emitting integer codes that are not
    members of ``State``.
    """
    config = SimulationConfig(n_households=2_000, n_hours=120, seed=0)
    result = simulate(config, np.random.default_rng(0))
    visited = {int(x) for x in np.unique(result.states)}
    expected = {int(s) for s in State}
    # Spell out the set difference so a failure says which states are off.
    assert visited == expected, (
        f"missing states: {sorted(expected - visited)}, "
        f"unexpected codes: {sorted(visited - expected)}"
    )
def test_simulator_terminal_sh_share_positive() -> None:
    """At the last simulated step, a strictly positive share of households is in SH."""
    config = SimulationConfig(n_households=1_000, n_hours=120, seed=0)
    final_states = simulate(config, np.random.default_rng(0)).states[:, -1]
    in_sh = final_states == int(State.SH)
    share = float(in_sh.mean())
    assert share > 0.0
def test_simulator_sh_is_absorbing() -> None:
    """SH is absorbing: a household in SH at step t is still in SH at t + 1."""
    config = SimulationConfig(n_households=500, n_hours=120, seed=1)
    states = simulate(config, np.random.default_rng(1)).states
    sh_code = int(State.SH)
    # A violation is any (household, t) with SH at t and a non-SH state at t + 1.
    left_sh = (states[:, :-1] == sh_code) & (states[:, 1:] != sh_code)
    assert not left_sh.any()
def test_simulator_evac_path_is_set_at_first_sh_entry() -> None:
    """``evac_path`` is non-NONE exactly for households that ever evacuated.

    Despite the test name, the checked property is broader than SH entry.
    ``evac_path`` is assigned at the PR -> {ER, SH} transition, so it must be
    non-NONE (0) for every household that ever occupies ER or SH, and NONE
    for all others. Only the final ``evac_path`` value is inspected here, not
    the step at which it was assigned.
    """
    config = SimulationConfig(n_households=500, n_hours=120, seed=2)
    result = simulate(config, np.random.default_rng(2))
    states = result.states
    path_set = result.evac_path != 0
    reached_sh = np.any(states == int(State.SH), axis=1)
    reached_er = np.any(states == int(State.ER), axis=1)

    # Reaching SH on its own must already imply a non-NONE path.
    assert path_set[reached_sh].all()

    evacuated = reached_sh | reached_er
    assert not path_set[~evacuated].any()
    assert path_set[evacuated].all()
def test_simulator_reproducible_under_seed() -> None:
    """Two runs with the same config and RNG seed produce the same outputs.

    Every array on the simulation result is compared, including
    ``evac_path``.
    """
    config = SimulationConfig(n_households=200, n_hours=24, seed=0)
    r1 = simulate(config, np.random.default_rng(config.seed))
    r2 = simulate(config, np.random.default_rng(config.seed))
    np.testing.assert_array_equal(r1.states, r2.states)
    np.testing.assert_array_equal(r1.departures, r2.departures)
    np.testing.assert_array_equal(r1.communications, r2.communications)
    # evac_path is also a simulator output and was previously not compared.
    np.testing.assert_array_equal(r1.evac_path, r2.evac_path)
    # displacements are compared with assert_allclose (default rtol=1e-7)
    # rather than exact equality; presumably a float array — TODO confirm.
    np.testing.assert_allclose(r1.displacements, r2.displacements)