src/iohmm_evac/dgp/emissions.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""State-conditional emission model.

Three emission channels, sampled per household per step:

* ``D`` — noisy departure indicator (Bernoulli)
* ``X`` — displacement from home in km (Gaussian / half-Gaussian)
* ``C`` — communication-activity count (Poisson)
"""

from __future__ import annotations

import numpy as np
from numpy.random import Generator

from iohmm_evac.params import EmissionParams
from iohmm_evac.types import BoolArray, FloatArray, IntArray, State

# Public API: only the sampler is exported; the per-state rate helpers in this
# module are private (underscore-prefixed).
__all__ = ["sample_emissions"]


def _state_lambda(emissions: EmissionParams) -> FloatArray:
    """Return the Poisson communication rate for every latent state.

    The vector is indexed by ``State``: its entries are, in order, the rates
    for UA, AW, PR, ER and SH.
    """
    rates = (
        emissions.lambda_ua,
        emissions.lambda_aw,
        emissions.lambda_pr,
        emissions.lambda_er,
        emissions.lambda_sh,
    )
    return np.asarray(rates, dtype=np.float64)


def _state_departure_p(emissions: EmissionParams) -> FloatArray:
    """Return the Bernoulli probability of observing ``D = 1`` in each state.

    Every state shares ``p_departure_other`` except ``State.ER``, whose slot is
    overwritten with ``p_departure_er``. The vector has ``State.n_states()``
    entries and is indexed by integer state value.
    """
    probs = np.empty(State.n_states(), dtype=np.float64)
    probs.fill(emissions.p_departure_other)
    er_slot = int(State.ER)
    probs[er_slot] = emissions.p_departure_er
    return probs


def _sample_displacement(
    state: IntArray,
    evac_path: IntArray,
    tir: FloatArray,
    destination: FloatArray,
    congestion_t: float,
    emissions: EmissionParams,
    rng: Generator,
) -> FloatArray:
    """Draw displacement-from-home ``X`` for every household, branching on state.

    Draws are made in a fixed order (pre-route, en-route, sheltered-away,
    sheltered-home); changing that order changes the sampled values for a
    given seed.
    """
    n = state.shape[0]
    # Each valid state falls into exactly one group below, so no entry of this
    # buffer is left uninitialized (the caller validates the state range).
    displacement = np.empty(n, dtype=np.float64)

    idle_sigma = emissions.displacement_idle_sigma
    route_sigma = emissions.displacement_route_sigma
    dest_sigma = emissions.displacement_destination_sigma

    # Not yet on the road: half-Gaussian jitter around home.
    pre_route = (state == State.UA) | (state == State.AW) | (state == State.PR)
    if pre_route.any():
        displacement[pre_route] = np.abs(
            rng.normal(0.0, idle_sigma, size=int(pre_route.sum()))
        )

    en_route = state == State.ER
    if en_route.any():
        # Congestion reduces free-flow speed linearly. Clamp at zero so heavy
        # congestion stalls households instead of moving them backwards
        # (np.maximum still propagates NaN inputs unchanged).
        v_eff = np.maximum(
            emissions.v_free * (1.0 - emissions.congestion_penalty * congestion_t),
            0.0,
        )
        # Distance traveled is min(tir * v_eff, dest_i) plus N(0, route_sigma^2).
        progress = np.minimum(tir[en_route] * v_eff, destination[en_route])
        displacement[en_route] = progress + rng.normal(
            0.0, route_sigma, size=int(en_route.sum())
        )

    sheltered = state == State.SH
    if sheltered.any():
        # evac_path == 1 (AWAY): sheltered near the destination.
        # Anything else (0=NONE, 2=HOME): sheltered near home.
        away = sheltered & (evac_path == 1)
        home = sheltered & (evac_path != 1)
        if away.any():
            displacement[away] = rng.normal(
                destination[away], dest_sigma, size=int(away.sum())
            )
        if home.any():
            displacement[home] = np.abs(
                rng.normal(0.0, idle_sigma, size=int(home.sum()))
            )

    return displacement


def sample_emissions(
    state: IntArray,
    evac_path: IntArray,
    tir: FloatArray,
    destination: FloatArray,
    congestion_t: float,
    emissions: EmissionParams,
    rng: Generator,
) -> tuple[BoolArray, FloatArray, IntArray]:
    """Sample (departure, displacement, comm-count) for every household.

    Vectorized over households: every per-household array is indexed by
    household along its first axis.

    Args:
        state: Latent state per household, as integer ``State`` values.
        evac_path: Evacuation path per household, encoded as 0=NONE, 1=AWAY,
            2=HOME. Only consulted for sheltered (``SH``) households.
        tir: Per-household quantity multiplied by the effective speed to give
            en-route progress (presumably time in route; confirm with caller).
        destination: Per-household destination distance, in the same units as
            the returned displacement.
        congestion_t: Congestion level at this step; reduces ``v_free``
            linearly through ``congestion_penalty``.
        emissions: Emission parameters.
        rng: Random generator. Draws are made in a fixed order (departure,
            communication, then displacement by state group), so results are
            reproducible for a given generator state.

    Returns:
        ``(departure, displacement, comm)``: boolean departure indicators,
        float64 displacements and int64 communication counts, each of length
        ``state.shape[0]``.

    Raises:
        ValueError: If any entry of ``state`` lies outside
            ``[0, State.n_states())``. The check runs before any draw, so
            ``rng`` is left untouched on error.
    """
    n = state.shape[0]
    n_states = State.n_states()
    # Negative codes would otherwise wrap silently in the rate lookups and
    # leave their displacement entries uninitialized. Guard on n because
    # min()/max() of an empty array raise.
    if n and (state.min() < 0 or state.max() >= n_states):
        raise ValueError(
            f"state values must lie in [0, {n_states}); "
            f"got range [{state.min()}, {state.max()}]"
        )

    # Departure: Bernoulli with state-conditioned rate.
    departure = rng.random(size=n) < _state_departure_p(emissions)[state]

    # Communication: Poisson with state-conditioned rate.
    comm = rng.poisson(_state_lambda(emissions)[state], size=n).astype(np.int64)

    # Displacement: branch by state and (for SH) by evac_path.
    displacement = _sample_displacement(
        state, evac_path, tir, destination, congestion_t, emissions, rng
    )

    return departure, displacement, comm