# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Build the exogenous forcing timeline (forecast, warning orders, etc.)."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from numpy.random import Generator
from iohmm_evac.params import TimelineParams
from iohmm_evac.types import FloatArray, IntArray
__all__ = ["Timeline", "build_timeline", "local_risk_at"]
@dataclass(frozen=True, slots=True)
class Timeline:
"""Hourly exogenous inputs of length ``T+1``."""
forecast: FloatArray
"""Noisy forecast intensity, shape (T+1,)."""
voluntary: IntArray
"""Voluntary-evacuation indicator, shape (T+1,)."""
mandatory: IntArray
"""Mandatory-evacuation indicator, shape (T+1,)."""
time_since_order: FloatArray
"""Hours since the most recent escalation in (vol, mand), shape (T+1,).
At ``t=0`` no order has occurred; we use ``T+1`` as a sentinel "very large"
value so that downstream features expecting a small recency value treat
early time steps as effectively no-order.
"""
def _piecewise_forecast(
n_steps: int,
breakpoints: tuple[int, ...],
levels: tuple[float, ...],
) -> FloatArray:
"""Return the deterministic piecewise-constant forecast mean."""
if len(breakpoints) != len(levels):
msg = "forecast_breakpoints and forecast_levels must have equal length"
raise ValueError(msg)
if breakpoints[0] != 0:
msg = "forecast_breakpoints must start at 0"
raise ValueError(msg)
out = np.empty(n_steps, dtype=np.float64)
bp = [*breakpoints, n_steps]
for i, level in enumerate(levels):
lo, hi = bp[i], bp[i + 1]
out[lo:hi] = level
return out
def build_timeline(t_total: int, rng: Generator, params: TimelineParams) -> Timeline:
    """Assemble the hourly forcing series for ``t = 0, ..., t_total``.

    The forecast is the piecewise-constant mean defined by
    ``params.forecast_breakpoints`` / ``params.forecast_levels`` plus one
    Gaussian draw per hour (standard deviation
    ``params.forecast_noise_sigma``) taken from ``rng``. The voluntary and
    mandatory indicators are 1 from their respective order hour onwards.
    ``time_since_order`` counts hours since the most recent order hour and
    holds the sentinel ``t_total + 1`` before the first one.
    """
    n_steps = t_total + 1

    mean_forecast = _piecewise_forecast(
        n_steps, params.forecast_breakpoints, params.forecast_levels
    )
    noise = rng.normal(0.0, params.forecast_noise_sigma, size=n_steps)

    hours = np.arange(n_steps, dtype=np.int64)
    voluntary_on = (hours >= params.voluntary_hour).astype(np.int64)
    mandatory_on = (hours >= params.mandatory_hour).astype(np.int64)

    # The recency clock restarts at each escalation (voluntary switching on,
    # then mandatory switching on). Mark escalation hours with their own
    # index and everything else with -1; a running maximum then gives, for
    # every hour, the latest escalation seen so far (or -1 if none yet).
    is_escalation = (hours == params.voluntary_hour) | (hours == params.mandatory_hour)
    latest_escalation = np.maximum.accumulate(np.where(is_escalation, hours, -1))
    no_order_yet = float(t_total + 1)
    time_since_order = np.where(
        latest_escalation >= 0,
        (hours - latest_escalation).astype(np.float64),
        no_order_yet,
    )

    return Timeline(
        forecast=mean_forecast + noise,
        voluntary=voluntary_on,
        mandatory=mandatory_on,
        time_since_order=time_since_order,
    )
def local_risk_at(
    forecast_t: float,
    distance: FloatArray,
    zone_multiplier: FloatArray | None = None,
    *,
    decay_scale: float = 10.0,
) -> FloatArray:
    r"""Compute :math:`\rho_{i,t} = F_t \exp(-d_i / \lambda)` for one time step.

    With the default :math:`\lambda = 10` this is the baseline
    :math:`F_t \exp(-d_i / 10)`.

    Args:
        forecast_t: Forecast intensity :math:`F_t` at the current step.
        distance: Per-household distances :math:`d_i`, shape (N,).
        zone_multiplier: Optional per-household factor (shape (N,) or None)
            that lets scenarios scale risk for a subset of households
            (e.g., zone A under targeted messaging).
        decay_scale: E-folding distance :math:`\lambda`, in the same units
            as ``distance``. Must be positive.

    Returns:
        The local risk for every household, same shape as ``distance``.

    Raises:
        ValueError: If ``decay_scale`` is not strictly positive.
    """
    # ``not (x > 0)`` also rejects NaN, which ``x <= 0`` would let through.
    if not decay_scale > 0:
        msg = "decay_scale must be positive"
        raise ValueError(msg)
    rho = float(forecast_t) * np.exp(-distance / decay_scale)
    if zone_multiplier is not None:
        rho = rho * zone_multiplier
    return rho