# src/iohmm_evac/dgp/timeline.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Build the exogenous forcing timeline (forecast, warning orders, etc.)."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.random import Generator

from iohmm_evac.params import TimelineParams
from iohmm_evac.types import FloatArray, IntArray

# Public API of this module; ``_piecewise_forecast`` is a private helper and is
# intentionally not exported.
__all__ = ["Timeline", "build_timeline", "local_risk_at"]


@dataclass(frozen=True, slots=True)
class Timeline:
    """Hourly exogenous inputs of length ``T+1``."""

    forecast: FloatArray
    """Noisy forecast intensity, shape (T+1,)."""

    voluntary: IntArray
    """Voluntary-evacuation indicator, shape (T+1,)."""

    mandatory: IntArray
    """Mandatory-evacuation indicator, shape (T+1,)."""

    time_since_order: FloatArray
    """Hours since the most recent escalation in (vol, mand), shape (T+1,).

    At ``t=0`` no order has occurred; we use ``T+1`` as a sentinel "very large"
    value so that downstream features expecting a small recency value treat
    early time steps as effectively no-order.
    """


def _piecewise_forecast(
    n_steps: int,
    breakpoints: tuple[int, ...],
    levels: tuple[float, ...],
) -> FloatArray:
    """Return the deterministic piecewise-constant forecast mean.

    Segment ``i`` covers hours ``[breakpoints[i], breakpoints[i + 1])`` and the
    final segment runs to the end of the horizon. Breakpoints at or beyond
    ``n_steps`` are allowed; the affected segments are truncated or empty.

    Parameters
    ----------
    n_steps:
        Number of hourly steps to produce.
    breakpoints:
        Start hour of each segment; must be non-empty, start at 0 and be
        non-decreasing.
    levels:
        Forecast level of each segment; same length as ``breakpoints``.

    Returns
    -------
    FloatArray
        float64 array of shape ``(n_steps,)``.

    Raises
    ------
    ValueError
        If the lengths differ, the inputs are empty, the first breakpoint is
        not 0, or the breakpoints decrease anywhere.
    """
    if len(breakpoints) != len(levels):
        msg = "forecast_breakpoints and forecast_levels must have equal length"
        raise ValueError(msg)
    if len(breakpoints) == 0:
        # Previously surfaced as an opaque IndexError from ``breakpoints[0]``.
        msg = "forecast_breakpoints must be non-empty"
        raise ValueError(msg)
    if breakpoints[0] != 0:
        msg = "forecast_breakpoints must start at 0"
        raise ValueError(msg)
    bp = np.asarray(breakpoints)
    if np.any(np.diff(bp) < 0):
        # A decreasing breakpoint would silently let a later segment overwrite
        # an earlier one; reject it instead.
        msg = "forecast_breakpoints must be non-decreasing"
        raise ValueError(msg)
    # For each hour pick the last segment whose start is <= that hour. With
    # duplicate breakpoints ``side="right"`` selects the later segment, so a
    # zero-width segment never contributes (matching slice-assignment order).
    segment = np.searchsorted(bp, np.arange(n_steps), side="right") - 1
    return np.asarray(levels, dtype=np.float64)[segment]


def build_timeline(t_total: int, rng: Generator, params: TimelineParams) -> Timeline:
    """Construct the timeline of length ``t_total + 1`` (hours ``0..t_total``).

    Parameters
    ----------
    t_total:
        Final hour index ``T``; every returned array has length ``T + 1``.
    rng:
        Random generator; exactly one ``rng.normal`` draw of size ``T + 1`` is
        consumed, for the forecast noise.
    params:
        Supplies the forecast breakpoints/levels/noise sigma and the hours at
        which the voluntary and mandatory orders take effect.

    Returns
    -------
    Timeline
        Noisy forecast, 0/1 order indicators, and the hours since the most
        recent order escalation (sentinel ``T + 1`` before the first one).
    """
    n_steps = t_total + 1
    fbar = _piecewise_forecast(n_steps, params.forecast_breakpoints, params.forecast_levels)
    forecast = fbar + rng.normal(0.0, params.forecast_noise_sigma, size=n_steps)

    t = np.arange(n_steps, dtype=np.int64)
    voluntary = (t >= params.voluntary_hour).astype(np.int64)
    mandatory = (t >= params.mandatory_hour).astype(np.int64)

    # An escalation is any step at which either indicator switches on. Deriving
    # it from the indicators (instead of testing ``t == hour``) keeps
    # time_since_order consistent with them even when an order hour is
    # non-integral or <= 0; such orders register at the first step where the
    # indicator reads 1.
    escalation = (np.diff(voluntary, prepend=0) > 0) | (np.diff(mandatory, prepend=0) > 0)
    # Carry the most recent escalation step forward in time; -1 means "none yet".
    last_event = np.maximum.accumulate(np.where(escalation, t, -1))
    # Large sentinel so recency features treat pre-order steps as "no order".
    sentinel = float(t_total + 1)
    time_since_order = np.where(last_event >= 0, t - last_event, sentinel).astype(np.float64)

    return Timeline(
        forecast=forecast,
        voluntary=voluntary,
        mandatory=mandatory,
        time_since_order=time_since_order,
    )


def local_risk_at(
    forecast_t: float,
    distance: FloatArray,
    zone_multiplier: FloatArray | None = None,
    *,
    decay_length: float = 10.0,
) -> FloatArray:
    r"""Compute :math:`\rho_{i,t} = F_t \exp(-d_i / L)` for one time step.

    Parameters
    ----------
    forecast_t:
        Forecast intensity :math:`F_t` at the current step.
    distance:
        Per-household distances :math:`d_i`; any array-like is accepted.
    zone_multiplier:
        Optional per-household factor (shape (N,), or anything broadcastable
        against ``distance``) that lets scenarios scale risk for a subset of
        households (e.g., zone A under targeted messaging). ``None`` means no
        scaling.
    decay_length:
        Length scale :math:`L` of the exponential decay, in the same units as
        ``distance``. Defaults to 10, the previously hard-coded value.

    Returns
    -------
    FloatArray
        Local risk per household, shaped like ``distance`` (after broadcasting
        with ``zone_multiplier``).

    Raises
    ------
    ValueError
        If ``decay_length`` is not positive.
    """
    if decay_length <= 0:
        msg = "decay_length must be positive"
        raise ValueError(msg)
    # asarray lets callers pass plain sequences; unary minus fails on a list.
    dist = np.asarray(distance)
    rho = float(forecast_t) * np.exp(-dist / decay_length)
    if zone_multiplier is not None:
        rho = rho * zone_multiplier
    return rho