tests/test_timeline.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
from __future__ import annotations

import numpy as np

from iohmm_evac.dgp.timeline import build_timeline, local_risk_at
from iohmm_evac.params import TimelineParams


def test_forecast_piecewise_means() -> None:
    """With zero noise, the forecast equals the default piecewise plan.

    The default plan pinned here holds level 1.0 for hours [0, 48), 2.0 for
    [48, 72), 3.0 for [72, 96) and 4.0 from hour 96 to the end of the horizon.
    """
    rng = np.random.default_rng(0)
    tl = build_timeline(120, rng, TimelineParams(forecast_noise_sigma=0.0))
    # 121 samples for a 120-hour horizon (presumably hours 0..120 inclusive).
    assert tl.forecast.shape == (121,)
    # (start, stop, level); stop=None runs to the end of the series.
    segments = ((0, 48, 1.0), (48, 72, 2.0), (72, 96, 3.0), (96, None, 4.0))
    for start, stop, level in segments:
        # assert_allclose matches the other tests in this file and, unlike a
        # bare ``assert np.allclose(...)``, reports the mismatching entries.
        np.testing.assert_allclose(
            tl.forecast[start:stop],
            level,
            err_msg=f"forecast segment starting at hour {start}",
        )


def test_warning_orders_monotonic() -> None:
    """Voluntary and mandatory order indicators switch on once and stay on.

    Uses the default TimelineParams; the switch-on hours pinned below are
    60 for the voluntary order and 84 for the mandatory order.
    """
    rng = np.random.default_rng(0)
    tl = build_timeline(120, rng, TimelineParams())
    # Cast to float before differencing: np.diff on a boolean array returns
    # booleans (True where neighbours differ), so ``>= 0`` would hold
    # trivially and a 1 -> 0 drop would go unnoticed. Float is lossless for
    # bool, int and float inputs alike.
    voluntary = np.asarray(tl.voluntary, dtype=float)
    mandatory = np.asarray(tl.mandatory, dtype=float)
    assert np.all(np.diff(voluntary) >= 0)
    assert np.all(np.diff(mandatory) >= 0)
    # Exact switch-on hours for the defaults.
    assert tl.voluntary[60] == 1
    assert tl.voluntary[59] == 0
    assert tl.mandatory[84] == 1
    assert tl.mandatory[83] == 0


def test_time_since_order_resets() -> None:
    """time_since_order restarts from zero at every warning order.

    Orders are placed at hour 10 (voluntary) and hour 20 (mandatory). Before
    the first order the series carries a sentinel larger than the horizon.
    """
    generator = np.random.default_rng(0)
    params = TimelineParams(voluntary_hour=10, mandatory_hour=20)
    timeline = build_timeline(120, generator, params)
    elapsed = timeline.time_since_order
    # No order issued yet: sentinel value beyond the 120-hour horizon.
    assert elapsed[5] > 120
    # Each order resets the counter, which then grows by one per hour.
    for order_hour in (10, 20):
        assert elapsed[order_hour] == 0
        assert elapsed[order_hour + 5] == 5


def test_local_risk_decay_with_distance() -> None:
    """Local risk falls off as exp(-distance / 10) under the default call."""
    dist = np.array([0.0, 10.0, 20.0])
    peak = 2.0
    np.testing.assert_allclose(
        local_risk_at(peak, dist),
        peak * np.exp(-dist / 10.0),
    )


def test_local_risk_zone_multiplier() -> None:
    """A per-zone multiplier scales local risk element-wise."""
    unit_distance = np.array([1.0, 1.0, 1.0])
    per_zone = np.array([1.0, 2.0, 1.0])
    observed = local_risk_at(1.0, unit_distance, zone_multiplier=per_zone)
    # Level 1.0 decayed over distance 1, then scaled zone by zone.
    np.testing.assert_allclose(
        observed,
        np.array([1.0, 2.0, 1.0]) * np.exp(-0.1),
    )


def test_timeline_param_validation() -> None:
    """build_timeline rejects malformed forecast schedules with ValueError."""
    import pytest

    rng = np.random.default_rng(0)
    # (expected message fragment, TimelineParams overrides). Params are built
    # inside pytest.raises so validation may fire at construction or at use.
    invalid_schedules = [
        # breakpoints and levels must pair up one-to-one
        ("equal length", {"forecast_breakpoints": (0,), "forecast_levels": (1.0, 2.0)}),
        # the first breakpoint must be hour 0
        ("must start at 0", {"forecast_breakpoints": (1, 5), "forecast_levels": (1.0, 2.0)}),
    ]
    for fragment, overrides in invalid_schedules:
        with pytest.raises(ValueError, match=fragment):
            build_timeline(10, rng, TimelineParams(**overrides))