# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Multinomial-logit state transitions, vectorized per origin state.

For each origin state ``k``, this module computes the linear logits for every
allowed destination (with the self-loop logit pinned at 0 as the reference
category), applies a row-wise softmax, and samples the next state for every
household currently in ``k``. SH is absorbing: households in SH never leave it.
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from numpy.random import Generator
from iohmm_evac.params import TransitionParams
from iohmm_evac.types import FloatArray, IntArray, State
# Names exported by ``from ... import *``; ``_``-prefixed helpers stay private.
__all__ = [
    "StepInputs",
    "categorical",
    "sample_transitions",
    "softmax",
    "transition_probs_for_state",
]
@dataclass(frozen=True, slots=True)
class StepInputs:
"""Per-step exogenous and feedback inputs needed by the transition model."""
rho: FloatArray
"""Local risk per household, shape (N,)."""
vol: int
mand: int
tau_norm: float
"""``(1 - τ_t / T) = t / T``; closer to landfall ⇒ closer to 1."""
pi: float
"""Peer-departure share."""
c: float
"""Network congestion."""
def softmax(logits: FloatArray) -> FloatArray:
    """Apply a numerically stable softmax to each row of ``logits``.

    Every row is shifted by its own maximum before exponentiation, so large
    logits cannot overflow ``np.exp``; the shift cancels in the final ratio.

    Args:
        logits: 2-D array of unnormalized scores, one row per sample.

    Returns:
        float64 array of the same shape; each row sums to 1 (up to rounding).
    """
    row_peak = np.max(logits, axis=1, keepdims=True)
    weights = np.exp(np.subtract(logits, row_peak))
    row_total = np.sum(weights, axis=1, keepdims=True)
    normalized = weights / row_total
    return np.asarray(normalized, dtype=np.float64)
def categorical(probs: FloatArray, rng: Generator) -> IntArray:
    """Draw one category index per row of ``probs`` by inverse-CDF sampling.

    One uniform variate is drawn per row; the sampled index is the first
    column whose running (cumulative) probability exceeds that variate.

    Args:
        probs: 2-D array whose rows are probability vectors.
        rng: NumPy generator supplying one uniform draw per row.

    Returns:
        int64 array of sampled column indices, one per row.
    """
    n_rows = probs.shape[0]
    cdf: FloatArray = np.cumsum(probs, axis=1)
    # Pin each row's CDF to end at exactly 1.0 so rounding in the running sum
    # can never leave a uniform draw beyond the last threshold.
    cdf[:, -1] = 1.0
    draws = np.asarray(rng.random(size=n_rows), dtype=np.float64)[:, np.newaxis]
    below = draws < cdf
    return np.asarray(below.argmax(axis=1), dtype=np.int64)
def _logits_from_ua(
    inputs: StepInputs, risk: FloatArray, idx: IntArray, params: TransitionParams
) -> FloatArray:
    """Build the logit matrix for households currently in UA.

    The result has one row per entry of ``idx`` and columns in destination
    order [UA, AW]: column 0 is the self-loop, fixed at 0, and column 1 is the
    linear predictor from ``params.ua_to_aw`` applied to each household's
    ``rho`` and ``risk`` plus the step-wide ``vol``, ``mand`` and ``tau_norm``.
    """
    coef = params.ua_to_aw
    household_rho = inputs.rho[idx]
    household_risk = risk[idx]
    # Term order is kept fixed so the floating-point sum is reproducible.
    to_aw = (
        coef.alpha
        + coef.beta_vol * inputs.vol
        + coef.beta_mand * inputs.mand
        + coef.beta_rho * household_rho
        + coef.beta_r * household_risk
        + coef.beta_tau * inputs.tau_norm
    )
    stay = np.zeros(idx.shape[0])
    return np.stack((stay, to_aw), axis=1)
def _logits_from_aw(
    inputs: StepInputs,
    risk: FloatArray,
    vehicle: FloatArray,
    idx: IntArray,
    params: TransitionParams,
) -> FloatArray:
    """Build the logit matrix for households currently in AW.

    The result has one row per entry of ``idx`` and columns in destination
    order [AW, UA, PR]: column 0 is the self-loop, fixed at 0; column 1 is the
    intercept-only ``params.aw_to_ua.alpha``; column 2 is the linear predictor
    from ``params.aw_to_pr`` over household ``rho``, ``risk`` and ``vehicle``
    plus the step-wide ``mand``, ``pi`` and ``tau_norm``.
    """
    count = idx.shape[0]
    coef = params.aw_to_pr
    household_rho = inputs.rho[idx]
    household_risk = risk[idx]
    household_vehicle = vehicle[idx]
    # Term order is kept fixed so the floating-point sum is reproducible.
    to_pr = (
        coef.alpha
        + coef.beta_mand * inputs.mand
        + coef.beta_rho * household_rho
        + coef.beta_pi * inputs.pi
        + coef.beta_r * household_risk
        + coef.beta_v * household_vehicle
        + coef.beta_tau * inputs.tau_norm
    )
    # Reverting to UA carries no covariates: one intercept for every household.
    to_ua = np.full(count, params.aw_to_ua.alpha, dtype=np.float64)
    stay = np.zeros(count)
    return np.stack((stay, to_ua, to_pr), axis=1)
def _logits_from_pr(
    inputs: StepInputs,
    risk: FloatArray,
    vehicle: FloatArray,
    idx: IntArray,
    params: TransitionParams,
) -> FloatArray:
    """Build the logit matrix for households currently in PR.

    The result has one row per entry of ``idx`` and columns in destination
    order [PR, ER, SH]: column 0 is the self-loop, fixed at 0; column 1 is the
    ``params.pr_to_er`` predictor (step-wide ``mand``, ``tau_norm`` and ``-c``
    plus household ``risk`` and ``vehicle``); column 2 is the
    ``params.pr_to_sh`` predictor on the negated household ``risk`` and
    ``vehicle``.
    """
    count = idx.shape[0]
    household_risk = risk[idx]
    household_vehicle = vehicle[idx]

    er = params.pr_to_er
    # Term order is kept fixed so the floating-point sums are reproducible.
    to_er = (
        er.alpha
        + er.beta_mand * inputs.mand
        + er.beta_tau * inputs.tau_norm
        + er.beta_negc * (-inputs.c)
        + er.beta_r * household_risk
        + er.beta_v * household_vehicle
    )

    sh = params.pr_to_sh
    to_sh = (
        sh.alpha
        + sh.beta_negr * (-household_risk)
        + sh.beta_negv * (-household_vehicle)
    )

    return np.stack((np.zeros(count), to_er, to_sh), axis=1)
def _logits_from_er(
    inputs: StepInputs, tir: FloatArray, idx: IntArray, params: TransitionParams
) -> FloatArray:
    """Build the logit matrix for households currently in ER.

    The result has one row per entry of ``idx`` and columns in destination
    order [ER, SH]: column 0 is the self-loop, fixed at 0, and column 1 is the
    ``params.er_to_sh`` predictor on household ``tir`` and the step-wide
    negated congestion ``-c``.
    """
    coef = params.er_to_sh
    household_tir = tir[idx]
    to_sh = coef.alpha + coef.beta_tir * household_tir + coef.beta_negc * (-inputs.c)
    return np.stack((np.zeros(idx.shape[0]), to_sh), axis=1)
# Destination codes per origin state. Entry ``j`` is the state assigned when
# column ``j`` of that origin's logit matrix is sampled, so each array must list
# states in the same order as the matching ``_logits_from_*`` builder's columns
# (entry 0 is always the self-loop).
#
# ``transition_probs_for_state`` hands these module-level arrays straight back
# to its callers, so they are made read-only: an accidental in-place edit by a
# caller would otherwise silently corrupt every later sampling step.
_DEST_UA = np.array([State.UA, State.AW], dtype=np.int64)
_DEST_AW = np.array([State.AW, State.UA, State.PR], dtype=np.int64)
_DEST_PR = np.array([State.PR, State.ER, State.SH], dtype=np.int64)
_DEST_ER = np.array([State.ER, State.SH], dtype=np.int64)
for _dest in (_DEST_UA, _DEST_AW, _DEST_PR, _DEST_ER):
    _dest.setflags(write=False)
del _dest
def transition_probs_for_state(
    origin: State,
    inputs: StepInputs,
    risk: FloatArray,
    vehicle: FloatArray,
    tir: FloatArray,
    idx: IntArray,
    params: TransitionParams,
) -> tuple[FloatArray, IntArray]:
    """Return (probabilities, destination-codes) for households at ``origin``.

    Useful for tests that need to inspect the probability matrix directly.

    Args:
        origin: Current state of the households selected by ``idx``. Integer
            state codes (e.g. NumPy integers read from a state vector) are
            accepted and converted with ``State(origin)``.
        inputs: Per-step exogenous and feedback inputs.
        risk: Per-household risk covariate, indexed with ``idx``.
        vehicle: Per-household vehicle covariate, indexed with ``idx``.
        tir: Per-household covariate used by the ER model, indexed with ``idx``.
        idx: Indices of the households currently in ``origin``.
        params: Transition coefficients.

    Returns:
        ``(probs, dests)``: ``probs`` has one row per entry of ``idx`` and one
        column per allowed destination; ``dests[j]`` is the state code for
        column ``j`` (column 0 is the self-loop).

    Raises:
        ValueError: If ``origin`` is not a valid ``State`` value, or is a state
            with no transition model defined here.
    """
    # The identity checks below only match genuine enum members. Normalizing
    # first keeps integer codes from slipping past every branch and being
    # silently treated as the absorbing SH state.
    origin = State(origin)
    if origin is State.UA:
        return softmax(_logits_from_ua(inputs, risk, idx, params)), _DEST_UA
    if origin is State.AW:
        return softmax(_logits_from_aw(inputs, risk, vehicle, idx, params)), _DEST_AW
    if origin is State.PR:
        return softmax(_logits_from_pr(inputs, risk, vehicle, idx, params)), _DEST_PR
    if origin is State.ER:
        return softmax(_logits_from_er(inputs, tir, idx, params)), _DEST_ER
    if origin is State.SH:
        # SH is absorbing: every household stays in SH with probability 1.
        n = idx.shape[0]
        return np.ones((n, 1), dtype=np.float64), np.array([State.SH], dtype=np.int64)
    raise ValueError(f"no transition model for origin state {origin!r}")
def sample_transitions(
    prev_state: IntArray,
    inputs: StepInputs,
    risk: FloatArray,
    vehicle: FloatArray,
    tir: FloatArray,
    params: TransitionParams,
    rng: Generator,
) -> IntArray:
    """Draw every household's next state from the multinomial-logit model.

    Households are grouped by current state and sampled one group at a time in
    the fixed order UA, AW, PR, ER, so the sequence of draws taken from ``rng``
    is deterministic for a given ``prev_state``. Households in SH, or in any
    state outside that list, keep their current state.

    Args:
        prev_state: Current state code of every household.
        inputs: Per-step exogenous and feedback inputs.
        risk: Per-household risk covariate.
        vehicle: Per-household vehicle covariate.
        tir: Per-household covariate used by the ER model.
        params: Transition coefficients.
        rng: Generator used for the categorical draws.

    Returns:
        A new state vector; ``prev_state`` itself is not modified.
    """
    next_state = prev_state.copy()
    for origin in (State.UA, State.AW, State.PR, State.ER):
        members = np.flatnonzero(prev_state == origin)
        if members.size == 0:
            continue
        probs, codes = transition_probs_for_state(
            origin, inputs, risk, vehicle, tir, members, params
        )
        # ``codes`` maps each sampled column index to its destination state.
        next_state[members] = codes[categorical(probs, rng)]
    return next_state