# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import numpy as np
from iohmm_evac.dgp.emissions import sample_emissions
from iohmm_evac.params import EmissionParams
from iohmm_evac.types import State
def _make_state_vector(s: State, n: int) -> np.ndarray:
return np.full(n, int(s), dtype=np.int64)
def test_emission_shapes() -> None:
    """Each emission stream yields one value per household.

    Uses 50 households in state AW with zero time-in-route and a destination
    of 60.0. The departure indicator must additionally be boolean.
    """
    n_households = 50
    generator = np.random.default_rng(0)
    states = _make_state_vector(State.AW, n_households)
    departures, displacement, counts = sample_emissions(
        states,
        np.zeros(n_households, dtype=np.int64),
        np.zeros(n_households),
        np.full(n_households, 60.0),
        0.0,
        EmissionParams(),
        generator,
    )
    expected_shape = (n_households,)
    assert departures.shape == expected_shape
    assert departures.dtype == bool
    for stream in (displacement, counts):
        assert stream.shape == expected_shape
def test_departure_rate_matches_state() -> None:
    """The empirical departure rate of ER households is close to 0.95.

    Draws 5,000 households (seed 1) and allows an absolute tolerance of 0.02.
    """
    n = 5_000
    states = _make_state_vector(State.ER, n)
    departed, _, _ = sample_emissions(
        states,
        np.zeros(n, dtype=np.int64),
        np.zeros(n),
        np.full(n, 60.0),
        0.0,
        EmissionParams(),
        np.random.default_rng(1),
    )
    target_rate = 0.95  # ER households depart with p ≈ 0.95.
    assert abs(departed.mean() - target_rate) < 0.02
def test_displacement_increases_with_tir_for_er() -> None:
    """Mean ER displacement grows when the time-in-route grows (0.5 -> 1.5).

    Both runs use a fresh generator seeded with 3, so each starts from an
    identical random stream.
    """
    n = 2_000
    states = _make_state_vector(State.ER, n)
    destination = np.full(n, 80.0)
    evac_path = np.zeros(n, dtype=np.int64)

    def mean_displacement(hours: float):
        # Fresh, identically seeded generator per run.
        _, x, _ = sample_emissions(
            states,
            evac_path,
            np.full(n, hours),
            destination,
            0.0,
            EmissionParams(),
            np.random.default_rng(3),
        )
        return x.mean()

    low = mean_displacement(0.5)
    high = mean_displacement(1.5)
    assert high > low
def test_displacement_capped_at_destination() -> None:
    """With a very large time-in-route, ER displacement saturates at the destination.

    Uses 1,000 ER households, time-in-route 100.0 and destination 50.0
    (seed 4). The mean displacement must lie within 0.2 of the destination.
    """
    n = 1_000
    destination = 50.0
    _, x, _ = sample_emissions(
        _make_state_vector(State.ER, n),
        np.zeros(n, dtype=np.int64),
        np.full(n, 100.0),
        np.full(n, destination),
        0.0,
        EmissionParams(),
        np.random.default_rng(4),
    )
    # After 100h at 40 km/h, raw progress is 4000; min with dest = 50.
    assert abs(x.mean() - destination) < 0.2
def test_poisson_rates_per_state() -> None:
    """The per-state count emission mean matches that state's Poisson rate.

    For each of UA, AW, PR, ER and SH, every household is placed in that state.
    The sample mean of the count stream ``c`` is then compared with the matching
    ``lambda_*`` field of ``EmissionParams`` (absolute tolerance 0.2). A single
    generator (seed 5) is shared across states, so the iteration order below
    determines which draws each state sees.
    """
    rng = np.random.default_rng(5)
    n = 4_000
    params = EmissionParams()
    expected = {
        State.UA: params.lambda_ua,
        State.AW: params.lambda_aw,
        State.PR: params.lambda_pr,
        State.ER: params.lambda_er,
        State.SH: params.lambda_sh,
    }
    for s, lam in expected.items():
        state = _make_state_vector(s, n)
        _, _, c = sample_emissions(
            state,
            np.zeros(n, np.int64),
            np.zeros(n),
            np.full(n, 60.0),
            0.0,
            params,
            rng,
        )
        observed = float(c.mean())
        # Name the failing state: without this, a loop-assert failure does not
        # reveal which of the five rates regressed.
        assert abs(observed - lam) < 0.2, (
            f"state {s!r}: mean count {observed:.4f} differs from rate {lam}"
        )
def test_sh_displacement_branches_on_evac_path() -> None:
    """SH displacement depends on the evac-path code.

    With path code 1 (called "away" here), mean displacement is close to the
    destination (70.0). With path code 2 ("home"), it is close to 0. Each run
    uses its own seeded generator (6 and 7 respectively).
    """
    n = 1_000
    states = _make_state_vector(State.SH, n)
    destination = np.full(n, 70.0)

    def displacement_for(path_code: int, seed: int) -> np.ndarray:
        _, x, _ = sample_emissions(
            states,
            np.full(n, path_code, dtype=np.int64),
            np.zeros(n),
            destination,
            0.0,
            EmissionParams(),
            np.random.default_rng(seed),
        )
        return x

    x_away = displacement_for(1, 6)
    x_home = displacement_for(2, 7)
    assert abs(x_away.mean() - 70.0) < 0.5
    assert abs(x_home.mean()) < 1.0