# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations
import numpy as np
from iohmm_evac.dgp.timeline import build_timeline, local_risk_at
from iohmm_evac.params import TimelineParams
def test_forecast_piecewise_means() -> None:
    """With zero forecast noise the forecast is exactly the piecewise plan."""
    params = TimelineParams(forecast_noise_sigma=0.0)
    timeline = build_timeline(120, np.random.default_rng(0), params)
    forecast = timeline.forecast
    # A 120-step horizon yields 121 samples (both endpoints included).
    assert forecast.shape == (121,)
    # (start, stop, level) for each constant segment; stop=None runs to the end.
    segments = (
        (0, 48, 1.0),
        (48, 72, 2.0),
        (72, 96, 3.0),
        (96, None, 4.0),
    )
    for start, stop, level in segments:
        assert np.allclose(forecast[start:stop], level)
def test_warning_orders_monotonic() -> None:
    """Order indicators switch on exactly at their configured hour and stay on.

    The voluntary/mandatory hours are passed explicitly rather than relying on
    ``TimelineParams`` defaults, so the boundary indices asserted below are
    tied to the configuration the test itself sets up.
    """
    voluntary_hour = 60
    mandatory_hour = 84
    params = TimelineParams(voluntary_hour=voluntary_hour, mandatory_hour=mandatory_hour)
    rng = np.random.default_rng(0)
    tl = build_timeline(120, rng, params)
    # Both indicators are non-decreasing (they switch on once and stay on).
    assert np.all(np.diff(tl.voluntary) >= 0)
    assert np.all(np.diff(tl.mandatory) >= 0)
    # Each indicator is off the hour before its order and on at the order hour.
    for indicator, hour in ((tl.voluntary, voluntary_hour), (tl.mandatory, mandatory_hour)):
        assert indicator[hour] == 1
        assert indicator[hour - 1] == 0
def test_time_since_order_resets() -> None:
    """time_since_order restarts at zero on each order and then counts up."""
    params = TimelineParams(voluntary_hour=10, mandatory_hour=20)
    since = build_timeline(120, np.random.default_rng(0), params).time_since_order
    # Before any order has been issued the series holds a large sentinel value.
    assert since[5] > 120
    # Each order (voluntary at 10, then mandatory at 20) resets the counter.
    for order_hour in (10, 20):
        assert since[order_hour] == 0
        assert since[order_hour + 5] == 5
def test_local_risk_decay_with_distance() -> None:
    """local_risk_at(level, d) equals level * exp(-d / 10)."""
    level = 2.0
    dist = np.array([0.0, 10.0, 20.0])
    np.testing.assert_allclose(
        local_risk_at(level, dist),
        level * np.exp(-dist / 10.0),
    )
def test_local_risk_zone_multiplier() -> None:
    """A zone multiplier scales the distance-decayed risk elementwise."""
    dist = np.ones(3)
    zone_mult = np.array([1.0, 2.0, 1.0])
    # At distance 1 the decay factor is exp(-1 / 10) = exp(-0.1).
    expected = zone_mult * np.exp(-0.1)
    observed = local_risk_at(1.0, dist, zone_multiplier=zone_mult)
    np.testing.assert_allclose(observed, expected)
def test_timeline_param_validation() -> None:
    """Malformed forecast breakpoint/level specs raise ValueError."""
    import pytest

    generator = np.random.default_rng(0)
    # (expected message fragment, TimelineParams kwargs) for each bad spec.
    bad_specs = (
        (
            "equal length",
            {"forecast_breakpoints": (0,), "forecast_levels": (1.0, 2.0)},
        ),
        (
            "must start at 0",
            {"forecast_breakpoints": (1, 5), "forecast_levels": (1.0, 2.0)},
        ),
    )
    for message, kwargs in bad_specs:
        # Params are built inside the context so an error raised at
        # construction time is also accepted, exactly as before.
        with pytest.raises(ValueError, match=message):
            build_timeline(10, generator, TimelineParams(**kwargs))