# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""State-conditional emission model.
Three emission channels, sampled per household per step:
* ``D`` — noisy departure indicator (Bernoulli)
* ``X`` — displacement from home in km (Gaussian / half-Gaussian)
* ``C`` — communication-activity count (Poisson)
"""
from __future__ import annotations
import numpy as np
from numpy.random import Generator
from iohmm_evac.params import EmissionParams
from iohmm_evac.types import BoolArray, FloatArray, IntArray, State
# Public API: only the sampler is exported; the per-state helpers are private.
__all__ = ["sample_emissions"]
def _state_lambda(emissions: EmissionParams) -> FloatArray:
"""Per-state Poisson rate vector, indexed by ``State``."""
return np.array(
[
emissions.lambda_ua,
emissions.lambda_aw,
emissions.lambda_pr,
emissions.lambda_er,
emissions.lambda_sh,
],
dtype=np.float64,
)
def _state_departure_p(emissions: EmissionParams) -> FloatArray:
    """Return the per-state probability that the departure flag ``D`` fires.

    Every state shares ``p_departure_other`` except ``State.ER``, which gets
    its own rate ``p_departure_er``. The result is indexed by ``State``.
    """
    probs = np.empty(State.n_states(), dtype=np.float64)
    probs.fill(emissions.p_departure_other)
    # Only the en-route state has a distinct departure rate.
    probs[int(State.ER)] = emissions.p_departure_er
    return probs
def sample_emissions(
    state: IntArray,
    evac_path: IntArray,
    tir: FloatArray,
    destination: FloatArray,
    congestion_t: float,
    emissions: EmissionParams,
    rng: Generator,
) -> tuple[BoolArray, FloatArray, IntArray]:
    """Sample (departure, displacement, comm-count) for every household.

    Vectorized over households. ``evac_path`` is encoded as 0=NONE, 1=AWAY,
    2=HOME.

    Args:
        state: Per-household latent state, as integer ``State`` values.
        evac_path: Per-household evacuation-path code (encoding above).
        tir: Per-household value multiplied by the effective speed to get the
            distance traveled by en-route households (presumably time in
            route — confirm units against the caller).
        destination: Per-household destination distance, in the same units as
            the displacement output (km per the module docstring).
        congestion_t: Congestion level this step; free-flow speed is scaled by
            ``1 - congestion_penalty * congestion_t`` (floored at zero).
        emissions: Emission parameters.
        rng: Random generator. Draws are consumed in a fixed order
            (departure, comm, pre-route, en-route, sheltered-away,
            sheltered-home), so results are reproducible for a given seed.

    Returns:
        ``(departure, displacement, comm)``: boolean, float64 and int64 arrays,
        one entry per household.
    """
    evac_away = 1  # ``evac_path`` code for AWAY (sheltering at destination).

    n = state.shape[0]
    p_vec = _state_departure_p(emissions)
    lam_vec = _state_lambda(emissions)

    # Departure: Bernoulli with state-conditioned rate.
    departure = rng.random(size=n) < p_vec[state]

    # Communication: Poisson with state-conditioned rate.
    comm = rng.poisson(lam_vec[state], size=n).astype(np.int64)

    # Displacement: branch by state and (for SH) by evac_path. Start from NaN
    # so a household whose state matches none of the branches is visible
    # instead of silently carrying uninitialised memory.
    displacement = np.full(n, np.nan, dtype=np.float64)
    idle_sigma = emissions.displacement_idle_sigma
    route_sigma = emissions.displacement_route_sigma
    dest_sigma = emissions.displacement_destination_sigma

    # Not yet departed: half-Gaussian jitter around home.
    pre_route = (state == State.UA) | (state == State.AW) | (state == State.PR)
    if pre_route.any():
        displacement[pre_route] = np.abs(
            rng.normal(0.0, idle_sigma, size=int(pre_route.sum()))
        )

    en_route = state == State.ER
    if en_route.any():
        # Congestion slows travel, but the effective speed must never go
        # negative (that would place households behind their home). NaN still
        # propagates through np.maximum.
        v_eff = np.maximum(
            emissions.v_free * (1.0 - emissions.congestion_penalty * congestion_t),
            0.0,
        )
        # Distance traveled is min(tir * v_eff, dest_i) plus N(0, route_sigma^2).
        progress = np.minimum(tir[en_route] * v_eff, destination[en_route])
        displacement[en_route] = progress + rng.normal(
            0.0, route_sigma, size=int(en_route.sum())
        )

    sheltered = state == State.SH
    if sheltered.any():
        # AWAY shelterers sit near their destination; every other path code
        # (NONE / HOME) shelters at home.
        away = sheltered & (evac_path == evac_away)
        home = sheltered & (evac_path != evac_away)
        if away.any():
            displacement[away] = rng.normal(
                destination[away], dest_sigma, size=int(away.sum())
            )
        if home.any():
            displacement[home] = np.abs(
                rng.normal(0.0, idle_sigma, size=int(home.sum()))
            )

    return departure, displacement, comm