# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Parameter dataclasses for IO-HMM inference, decoupled from the DGP side.
The IO-HMM views the world as ``K`` latent states, an input vector ``u_t``
of dimension ``F``, and three emission channels (Bernoulli ``D``, Gaussian
``X``, Poisson ``C``). Transitions are multinomial-logit:
A_kj(u) = exp(alpha_kj + beta_kj^T u) / sum_l exp(alpha_kl + beta_kl^T u)
with ``(alpha_kk, beta_kk) = (0, 0)`` for identifiability and
``alpha_kj = -inf`` for forbidden destinations.
Self-loops and forbidden-destination cells in the ``alpha`` and ``beta``
arrays are pinned to fixed values (``alpha = 0`` and ``beta = 0`` for
self-loops; ``alpha = -inf`` and ``beta = 0`` for forbidden destinations) and
never enter the optimization parameter vector — see ``learnable_indices``.
"""
from __future__ import annotations
from dataclasses import dataclass, field, replace
import numpy as np
from iohmm_evac.params import (
EmissionParams as DGPEmissionParams,
)
from iohmm_evac.params import (
PopulationParams as DGPPopulationParams,
)
from iohmm_evac.params import TransitionParams as DGPTransitionParams
from iohmm_evac.types import BoolArray, FloatArray, State
# Names exported by a star-import of this module. ``F`` and ``with_initial``
# are defined in this module but are not listed here, so they must be
# imported explicitly by name.
__all__ = [
    "ALLOWED_TRANSITIONS",
    "FEATURE_NAMES",
    "EmissionFitParams",
    "FitParameters",
    "InitialFitParams",
    "K",
    "TransitionFitParams",
    "allowed_mask",
    "dgp_truth_to_fit_init",
    "learnable_indices",
]
# Model dimensions used to size every parameter array in this module:
# ``K`` latent states and ``F`` exogenous input features.
K: int = 5
"""Number of latent states. Matches the DGP's ``State`` enum."""
FEATURE_NAMES: tuple[str, ...] = ("vol", "mand", "rho", "r", "v", "tau")
"""Exogenous IO-HMM input features, in order. ``F = len(FEATURE_NAMES)``.
Endogenous DGP features (``pi``, ``c``, ``tir``) are intentionally absent —
they depend on the latent state path and so cannot enter the inputs of an
inference model that does not know that path. This mis-specification is a
deliberate design choice; see ``docs/inference.md``.
"""
# Derived from FEATURE_NAMES rather than hard-coded, so the two cannot drift.
F: int = len(FEATURE_NAMES)
# Directed (source, destination) state pairs between which a transition is
# permitted. Self-loops are not listed here; ``allowed_mask`` adds them.
_ALLOWED_PAIRS: tuple[tuple[State, State], ...] = (
    (State.UA, State.AW),
    (State.AW, State.UA),
    (State.AW, State.PR),
    (State.PR, State.ER),
    (State.PR, State.SH),
    (State.ER, State.SH),
)
def allowed_mask() -> BoolArray:
    """Build the K×K boolean mask of permitted transitions.

    Entry ``[k, j]`` is True when the move ``k -> j`` is permitted: every
    self-loop ``k -> k`` plus each ``(src, dst)`` pair in ``_ALLOWED_PAIRS``.
    A freshly allocated array is returned on every call.
    """
    # The identity matrix already marks every self-loop as allowed.
    permitted = np.eye(K, dtype=bool)
    for origin, target in _ALLOWED_PAIRS:
        permitted[int(origin), int(target)] = True
    return permitted
# Computed once at import time. The same array object is shared by every
# reader of this constant (e.g. ``learnable_indices``), so treat it as
# read-only.
ALLOWED_TRANSITIONS: BoolArray = allowed_mask()
"""Module-level constant. Index ``[k, j]`` is True iff ``k -> j`` is allowed."""
def learnable_indices() -> tuple[BoolArray, BoolArray]:
    """Return the pair ``(allowed_non_self_mask, self_mask)``.

    * ``allowed_non_self_mask[k, j]`` is True for allowed transitions with
      ``k != j`` — the cells that participate in L-BFGS optimization.
    * ``self_mask[k, j]`` is True iff ``k == j``.

    Both masks are K×K boolean arrays allocated on each call.
    """
    diagonal = np.eye(K, dtype=bool)
    off_diagonal = ~diagonal
    return ALLOWED_TRANSITIONS & off_diagonal, diagonal
@dataclass(frozen=True, slots=True)
class InitialFitParams:
"""Initial-state distribution.
``logits`` are unnormalized scores; the actual distribution is the
softmax. Free parameters: ``K - 1`` (one anchored to 0).
"""
logits: FloatArray
def probs(self) -> FloatArray:
"""Return the normalized initial-state distribution."""
m = np.max(self.logits)
e = np.exp(self.logits - m)
out = e / e.sum()
return np.asarray(out, dtype=np.float64)
@dataclass(frozen=True, slots=True)
class TransitionFitParams:
    """K×K transition logit parameters under the IO-HMM input ``u``.

    ``alpha[k, j]`` is the intercept and ``beta[k, j]`` the length-``F``
    coefficient vector of the logit for moving from state ``k`` to state
    ``j``. The dataclass is frozen, but the arrays are stored as given:
    nothing here copies them or checks their shapes or pinned cells.
    """

    alpha: FloatArray  # shape (K, K); self-loops 0, forbidden -inf
    beta: FloatArray  # shape (K, K, F); self-loops zeros, forbidden zeros
@dataclass(frozen=True, slots=True)
class EmissionFitParams:
    """State-conditional emission parameters.

    One entry per latent state for each of the three emission channels:
    Bernoulli departure, Gaussian displacement and Poisson communication.
    """

    p_departure: FloatArray  # shape (K,) Bernoulli rates
    mu_displacement: FloatArray  # shape (K,)
    sigma_displacement: FloatArray  # shape (K,) std-devs (>= sigma_floor)
    lambda_comm: FloatArray  # shape (K,) Poisson rates
    # Lower bound intended for ``sigma_displacement``.
    # NOTE(review): nothing in this class enforces the bound; presumably the
    # fitting code applies it — confirm against the M-step.
    sigma_floor: float = 1e-2
@dataclass(frozen=True, slots=True)
class FitParameters:
    """Top-level IO-HMM parameter bundle.

    Groups the initial-state, transition and emission parameters together
    with the ordered names of the input features the ``beta`` coefficients
    refer to.
    """

    initial: InitialFitParams
    transitions: TransitionFitParams
    emissions: EmissionFitParams
    # Defaults to the module-level FEATURE_NAMES tuple.
    feature_names: tuple[str, ...] = field(default=FEATURE_NAMES)
def _trans_pair_to_io(
    dgp_alpha: float,
    dgp_betas: dict[str, float],
) -> tuple[float, FloatArray]:
    """Project one DGP transition row onto IO-HMM ``(alpha, beta_vec)``.

    The DGP applies ``beta_negc``, ``beta_negr`` and ``beta_negv`` to the
    *negated* features ``-c``, ``-r`` and ``-v`` (negation, not ``1 - x``).
    The IO-HMM input vector carries ``r`` and ``v`` directly, so the
    ``neg`` coefficients are folded in with a flipped sign:
    ``beta_r - beta_negr`` and ``beta_v - beta_negv``. ``c`` is endogenous
    and absent from the IO-HMM inputs, so ``beta_negc`` has no image and is
    dropped; ``beta_pi`` and ``beta_tir`` are likewise never read here.

    Coefficients missing from ``dgp_betas`` count as ``0.0``.

    Returns:
        ``dgp_alpha`` unchanged, and a float64 vector of length ``F``
        ordered like ``FEATURE_NAMES``.
    """
    coef = dgp_betas.get
    # IO-HMM feature name -> coefficient after sign folding.
    projected = {
        "vol": coef("beta_vol", 0.0),
        "mand": coef("beta_mand", 0.0),
        "rho": coef("beta_rho", 0.0),
        "r": coef("beta_r", 0.0) - coef("beta_negr", 0.0),
        "v": coef("beta_v", 0.0) - coef("beta_negv", 0.0),
        "tau": coef("beta_tau", 0.0),
    }
    slot_of = {name: pos for pos, name in enumerate(FEATURE_NAMES)}
    beta_vec = np.zeros(F, dtype=np.float64)
    for name, value in projected.items():
        beta_vec[slot_of[name]] = value
    return dgp_alpha, beta_vec
def _row_dict(row: object) -> dict[str, float]:
    """Snapshot a DGP transition row's coefficients as a ``name -> float`` dict.

    Reads ``alpha`` and every ``beta_*`` attribute listed below from ``row``
    and converts each to ``float``. Raises ``AttributeError`` if ``row``
    lacks one of them. Keys appear in the order listed.
    """
    coefficient_names = (
        "alpha",
        "beta_vol",
        "beta_mand",
        "beta_rho",
        "beta_pi",
        "beta_r",
        "beta_v",
        "beta_tau",
        "beta_negc",
        "beta_negr",
        "beta_negv",
        "beta_tir",
    )
    snapshot: dict[str, float] = {}
    for name in coefficient_names:
        snapshot[name] = float(getattr(row, name))
    return snapshot
def _dgp_displacement_moments(
    emissions: DGPEmissionParams, population: DGPPopulationParams
) -> tuple[FloatArray, FloatArray]:
    """Compute DGP-implied per-state displacement (μ, σ) for the IO-HMM init.

    Reads ``emissions.displacement_idle_sigma`` and ``population.dest_lo`` /
    ``population.dest_hi``; returns two float64 arrays of length 5, one
    entry per state in the positional order UA, AW, PR, ER, SH.

    Per-state derivations:

    * **UA, AW, PR — half-normal** ``|N(0, σ²)|`` with
      ``σ = displacement_idle_sigma``:
      ``E[X] = σ √(2/π)``, ``Var(X) = σ²(1 - 2/π)``.
    * **ER — mid-evacuation approximation**:
      ``μ = (dest_lo + dest_hi) / 4``,
      ``σ = √((dest_hi - dest_lo)² / 12 + 5²)``. The ``+5²`` term is a
      conservative inflation for ``tir``/``c_t`` variability that the
      IO-HMM does not see.
    * **SH — uniform** ``Uniform(dest_lo, dest_hi)`` for the dominant
      ``away`` mode (the small ``home`` component is folded into the same
      Gaussian as a first approximation):
      ``E[X] = (dest_lo + dest_hi) / 2``,
      ``Var(X) = (dest_hi - dest_lo)² / 12``.
    """
    idle_scale = float(emissions.displacement_idle_sigma)
    lo = float(population.dest_lo)
    hi = float(population.dest_hi)
    width = hi - lo
    endpoint_sum = lo + hi
    # Variance of Uniform(lo, hi); shared by the ER and SH formulas.
    uniform_var = width * width / 12.0

    idle = (
        idle_scale * np.sqrt(2.0 / np.pi),
        idle_scale * np.sqrt(1.0 - 2.0 / np.pi),
    )
    # 25.0 is the 5² inflation term described in the docstring.
    en_route = (endpoint_sum / 4.0, float(np.sqrt(uniform_var + 25.0)))
    sheltered = (endpoint_sum / 2.0, float(np.sqrt(uniform_var)))

    # The three idle states share one (μ, σ) pair.
    per_state = [idle, idle, idle, en_route, sheltered]
    mu = np.array([m for m, _ in per_state], dtype=np.float64)
    sigma = np.array([s for _, s in per_state], dtype=np.float64)
    return mu, sigma
def dgp_truth_to_fit_init(
    transitions: DGPTransitionParams,
    emissions: DGPEmissionParams,
    population: DGPPopulationParams | None = None,
) -> FitParameters:
    """Project the DGP's true parameters onto an IO-HMM :class:`FitParameters`.

    Serves the ``--init truth`` testing path and acts as the reference point
    for parameter-recovery diagnostics. DGP coefficients on endogenous
    feedback features (``beta_pi``, ``beta_negc``, ``beta_tir``) have no
    IO-HMM counterpart and are dropped by the projection.

    Args:
        transitions: DGP transition rows, one attribute per allowed
            non-self move (``ua_to_aw``, ``aw_to_ua``, ``aw_to_pr``,
            ``pr_to_er``, ``pr_to_sh``, ``er_to_sh``).
        emissions: DGP emission parameters (departure probabilities,
            idle displacement sigma, per-state communication rates).
        population: Source of ``dest_lo`` / ``dest_hi`` for the
            displacement moments (see :func:`_dgp_displacement_moments`).
            A default-constructed :class:`DGPPopulationParams` is used
            when omitted.

    Returns:
        A :class:`FitParameters` whose ``feature_names`` is left at its
        default.
    """
    if population is None:
        population = DGPPopulationParams()

    # Transitions: every cell starts as forbidden (alpha = -inf, beta = 0),
    # the diagonal is pinned to alpha = 0, and the allowed off-diagonal
    # cells are filled from the DGP rows below.
    alpha = np.full((K, K), -np.inf, dtype=np.float64)
    beta = np.zeros((K, K, F), dtype=np.float64)
    np.fill_diagonal(alpha, 0.0)

    # Function-scope import; the name is only used to annotate ``dgp_rows``.
    from iohmm_evac.params import TransitionRow as _TransitionRow

    dgp_rows: list[tuple[State, State, _TransitionRow]] = [
        (State.UA, State.AW, transitions.ua_to_aw),
        (State.AW, State.UA, transitions.aw_to_ua),
        (State.AW, State.PR, transitions.aw_to_pr),
        (State.PR, State.ER, transitions.pr_to_er),
        (State.PR, State.SH, transitions.pr_to_sh),
        (State.ER, State.SH, transitions.er_to_sh),
    ]
    for src, dst, row in dgp_rows:
        intercept, coefs = _trans_pair_to_io(float(row.alpha), _row_dict(row))
        cell = (int(src), int(dst))
        alpha[cell] = intercept
        beta[cell] = coefs

    # Initial distribution: logit 0 for UA against -10 elsewhere, so the
    # softmax puts almost all mass on UA.
    start_logits = np.full(K, -10.0, dtype=np.float64)
    start_logits[int(State.UA)] = 0.0

    # Emissions: only ER has its own departure probability.
    departure = np.full(K, emissions.p_departure_other, dtype=np.float64)
    departure[int(State.ER)] = emissions.p_departure_er
    disp_mu, disp_sigma = _dgp_displacement_moments(emissions, population)
    comm_rates = np.array(
        [
            emissions.lambda_ua,
            emissions.lambda_aw,
            emissions.lambda_pr,
            emissions.lambda_er,
            emissions.lambda_sh,
        ],
        dtype=np.float64,
    )

    return FitParameters(
        initial=InitialFitParams(logits=start_logits),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=EmissionFitParams(
            p_departure=departure,
            mu_displacement=disp_mu,
            sigma_displacement=disp_sigma,
            lambda_comm=comm_rates,
        ),
    )
def with_initial(params: FitParameters, initial: InitialFitParams) -> FitParameters:
    """Return a copy of ``params`` whose ``initial`` field is ``initial``.

    ``params`` itself is left untouched; every other field is carried over
    unchanged. Helper for the E/M-step.
    """
    updated = replace(params, initial=initial)
    return updated