# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import numpy as np
from iohmm_evac.network.metrics import (
NetworkMetrics,
compute_metrics_from_arrays,
peak_enroute_share_and_hour,
)
from iohmm_evac.types import State
# Plain-int codes for the ``State`` members used to build the fixtures below.
# ER is the en-route state counted by the "enroute" metrics and SH is the
# shelter state counted for arrivals/overflow; UA, PR and AW are other states
# used as filler in the fixtures — see ``iohmm_evac.types.State`` for their
# definitions.
ER, SH, UA, PR, AW = (
    int(member)
    for member in (State.ER, State.SH, State.UA, State.PR, State.AW)
)
def _states(rows: list[list[int]]) -> np.ndarray:
return np.asarray(rows, dtype=np.int64)
def _disp(rows: list[list[float]]) -> np.ndarray:
return np.asarray(rows, dtype=np.float64)
def test_total_delay_matches_hand_computed_value() -> None:
    """Check ``total_delay_hours`` against a value derived by hand.

    The expected value is compared with a tight relative tolerance rather
    than exact equality: the hand derivation and the implementation need not
    perform their floating-point operations in the same order.
    """
    # Two households, three hours. Household 0 is ER at t=1 and t=2.
    states = _states(
        [
            [PR, ER, ER],
            [PR, PR, ER],
        ]
    )
    displacements = _disp(
        [
            [0.0, 5.0, 12.0],
            [0.0, 0.0, 4.0],
        ]
    )
    evac_path = np.array([1, 1], dtype=np.int64)  # both AWAY
    n_cap = 2
    v_free = 40.0
    alpha = 0.6
    # c_t derived from states[:, t-1]:
    #   c_0 = 0 (no lag); c_1 = #ER at t=0 / n_cap = 0/2 = 0;
    #   c_2 = #ER at t=1 / n_cap = 1/2 = 0.5
    # v_eff_0 = 40, v_eff_1 = 40, v_eff_2 = 40 * (1 - 0.6 * 0.5) = 28
    # 1/v_eff_2 - 1/v_free = 1/28 - 1/40
    # Only ER hours contribute.
    #   Household 0: t=1 ER, delta=5, c_1=0 → 0; t=2 ER, delta=7, contrib 7*(1/28-1/40)
    #   Household 1: t=2 ER, delta=4, contrib 4*(1/28-1/40)
    expected = (7 + 4) * (1.0 / 28.0 - 1.0 / 40.0)
    metrics = compute_metrics_from_arrays(
        states=states,
        displacements=displacements,
        evac_path=evac_path,
        n_cap=n_cap,
        shelter_capacity=0,
        v_free=v_free,
        congestion_penalty=alpha,
    )
    # rtol=1e-12 still pins the value to ~12 significant digits while
    # tolerating last-bit rounding differences; atol=0 keeps it purely relative.
    np.testing.assert_allclose(
        metrics.total_delay_hours, expected, rtol=1e-12, atol=0.0
    )
def test_peak_enroute_share_and_hour_helper() -> None:
    """The helper reports the peak ER share and the earliest hour reaching it."""
    # Three households over four hours.
    # ER counts per hour: 0, 1, 2, 2, so the shares are counts / 3 and the
    # maximum is tied between t=2 and t=3; the earlier hour is expected.
    grid = _states(
        [
            [UA, ER, ER, SH],
            [UA, AW, ER, ER],
            [UA, AW, AW, ER],
        ]
    )
    peak_share, peak_hour = peak_enroute_share_and_hour(grid)
    assert peak_hour == 2
    assert peak_share == 2.0 / 3.0
def test_peak_enroute_in_metrics_picks_correct_hour() -> None:
    """The metrics bundle reports the hour with the most ER households."""
    grid = _states(
        [
            [PR, ER, ER, SH],
            [PR, ER, SH, SH],
        ]
    )
    result = compute_metrics_from_arrays(
        states=grid,
        displacements=np.zeros_like(grid, dtype=np.float64),
        evac_path=np.array([1, 1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=10,
        v_free=40.0,
    )
    # ER counts per hour are 0, 2, 1, 0: both households are en route at t=1,
    # so the peak hour is 1 and the peak share is 2/2 = 1.0.
    assert result.peak_enroute_hour == 1
    assert result.peak_enroute_share == 1.0
def test_peak_enroute_share_is_at_most_one() -> None:
    """With every household ER in the same hour the peak share is exactly 1."""
    # Saturated case: three identical households, all ER at t=1.
    grid = _states([[PR, ER]] * 3)
    result = compute_metrics_from_arrays(
        states=grid,
        displacements=np.zeros_like(grid, dtype=np.float64),
        evac_path=np.array([1, 1, 1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=10,
        v_free=40.0,
    )
    assert 0.0 <= result.peak_enroute_share <= 1.0
    assert result.peak_enroute_share == 1.0
def test_failed_evacuation_count_is_er_at_horizon() -> None:
    """Households still ER in the final hour are counted as failed evacuations."""
    grid = _states(
        [
            [PR, ER, ER],
            [PR, ER, SH],
            [PR, PR, ER],
        ]
    )
    result = compute_metrics_from_arrays(
        states=grid,
        displacements=np.zeros_like(grid, dtype=np.float64),
        evac_path=np.array([1, 1, 1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=10,
        v_free=40.0,
    )
    # In the last hour (t=2) households 0 and 2 are ER; household 1 is in SH.
    assert result.failed_evacuation_count == 2
def test_shelter_overflow_with_zero_capacity_counts_every_arrival() -> None:
    """With zero shelter capacity every AWAY arrival into SH overflows."""
    grid = _states(
        [
            [PR, ER, SH],
            [PR, ER, SH],
            [PR, PR, SH],
        ]
    )
    # Households 0 and 1 take the AWAY path (code 1); household 2 is HOME
    # (code 2) and therefore must not be counted as a shelter arrival.
    paths = np.array([1, 1, 2], dtype=np.int64)
    result = compute_metrics_from_arrays(
        states=grid,
        displacements=np.zeros_like(grid, dtype=np.float64),
        evac_path=paths,
        n_cap=10,
        shelter_capacity=0,
        v_free=40.0,
    )
    # Only the two AWAY households entering SH count toward the overflow.
    assert result.shelter_overflow_count == 2
def test_shelter_overflow_with_large_capacity_is_zero() -> None:
    """A shelter far larger than the population never overflows."""
    # Two identical households that both reach SH at t=2.
    grid = _states([[PR, ER, SH]] * 2)
    result = compute_metrics_from_arrays(
        states=grid,
        displacements=np.zeros_like(grid, dtype=np.float64),
        evac_path=np.array([1, 1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=1_000_000,
        v_free=40.0,
    )
    assert result.shelter_overflow_count == 0
def test_total_delay_zero_with_huge_n_cap() -> None:
    """A very large ``n_cap`` makes the total congestion delay negligible."""
    # A large n_cap pushes c_t toward zero, so v_eff ≈ v_free and the delay
    # contribution should be ~0. Whether it is exactly 0 depends on how the
    # implementation computes the congestion ratio, so only near-zero is
    # asserted here.
    states = _states(
        [
            [PR, ER, ER],
            [PR, ER, ER],
        ]
    )
    displacements = _disp(
        [
            [0.0, 5.0, 10.0],
            [0.0, 6.0, 11.0],
        ]
    )
    evac_path = np.array([1, 1], dtype=np.int64)
    metrics = compute_metrics_from_arrays(
        states=states,
        displacements=displacements,
        evac_path=evac_path,
        n_cap=10**9,
        shelter_capacity=10,
        v_free=40.0,
    )
    # Bound the magnitude, not just the upper side: a one-sided ``< 1e-6``
    # check would also accept an arbitrarily large negative delay.
    assert abs(metrics.total_delay_hours) < 1e-6
def test_diagnostic_arrays_shape_and_values() -> None:
    """Per-hour diagnostic arrays have one entry per hour and the right counts."""
    # Two households over five hours.
    grid = _states(
        [
            [UA, AW, PR, ER, SH],
            [UA, AW, AW, PR, ER],
        ]
    )
    result = compute_metrics_from_arrays(
        states=grid,
        displacements=np.zeros_like(grid, dtype=np.float64),
        evac_path=np.array([1, 1], dtype=np.int64),
        n_cap=5,
        shelter_capacity=10,
        v_free=40.0,
    )
    hourly_shape = (5,)
    assert result.delay_per_hour.shape == hourly_shape
    assert result.enroute_count_per_hour.shape == hourly_shape
    assert result.arrivals_away_per_hour.shape == hourly_shape
    # Household 0 is ER at t=3 and household 1 at t=4, giving 0,0,0,1,1.
    assert np.array_equal(
        result.enroute_count_per_hour,
        np.array([0, 0, 0, 1, 1], dtype=np.int64),
    )
    # The only SH-away arrival is household 0 entering SH at t=4; hour 0
    # contributes nothing.
    assert np.array_equal(
        result.arrivals_away_per_hour,
        np.array([0, 0, 0, 0, 1], dtype=np.int64),
    )
def test_metrics_dataclass_is_frozen() -> None:
    """``NetworkMetrics`` rejects in-place mutation but supports ``replace``.

    ``dataclasses.replace`` succeeds on any dataclass, frozen or not, so on
    its own it proves nothing about immutability. The test therefore also
    attempts a direct attribute assignment and requires it to fail.
    """
    import dataclasses

    states = _states([[PR, ER]])
    metrics = compute_metrics_from_arrays(
        states=states,
        displacements=np.zeros_like(states, dtype=np.float64),
        evac_path=np.array([1], dtype=np.int64),
        n_cap=10,
        shelter_capacity=10,
        v_free=40.0,
    )
    assert isinstance(metrics, NetworkMetrics)
    original_delay = metrics.total_delay_hours

    # Assigning to a field of a frozen dataclass raises FrozenInstanceError;
    # a regular (mutable) dataclass would silently accept the write.
    try:
        metrics.total_delay_hours = 99.0  # type: ignore[misc]
    except dataclasses.FrozenInstanceError:
        pass
    else:  # pragma: no cover
        raise AssertionError("expected FrozenInstanceError on attribute assignment")
    assert metrics.total_delay_hours == original_delay

    # ``replace`` builds a new instance, so it must keep working on a frozen
    # dataclass and must leave the original untouched.
    updated = dataclasses.replace(metrics, total_delay_hours=99.0)
    assert updated.total_delay_hours == 99.0
    assert metrics.total_delay_hours == original_delay
def test_n_cap_must_be_positive() -> None:
    """``n_cap=0`` is rejected with a ``ValueError`` that names the argument."""
    grid = _states([[PR, ER]])
    message: str | None = None
    try:
        compute_metrics_from_arrays(
            states=grid,
            displacements=np.zeros_like(grid, dtype=np.float64),
            evac_path=np.array([1], dtype=np.int64),
            n_cap=0,
            shelter_capacity=10,
            v_free=40.0,
        )
    except ValueError as exc:
        message = str(exc)
    # No message captured means the call returned without raising.
    if message is None:  # pragma: no cover
        raise AssertionError("expected ValueError")
    assert "n_cap" in message
def test_v_free_must_be_positive() -> None:
    """``v_free=0.0`` is rejected with a ``ValueError`` that names the argument."""
    grid = _states([[PR, ER]])
    message: str | None = None
    try:
        compute_metrics_from_arrays(
            states=grid,
            displacements=np.zeros_like(grid, dtype=np.float64),
            evac_path=np.array([1], dtype=np.int64),
            n_cap=10,
            shelter_capacity=10,
            v_free=0.0,
        )
    except ValueError as exc:
        message = str(exc)
    # No message captured means the call returned without raising.
    if message is None:  # pragma: no cover
        raise AssertionError("expected ValueError")
    assert "v_free" in message