docs/inference.md

# IO-HMM inference

This page is the methodological writeup for the inference half of the
project: forward-backward, EM, and the L-BFGS-B M-step for the
multinomial-logit transition rows. It mirrors the code in
[`src/iohmm_evac/inference/`](../src/iohmm_evac/inference) and uses the same
notation. Build 2 introduces all of this; before Build 2 the project only
generated data, never fit it.

## Notation

Per household ``i`` and time ``t`` (with ``t = 0..T``):

- ``S_{i,t} ∈ {0, …, K-1}`` — latent state.
- ``y_{i,t} = (D_{i,t}, X_{i,t}, C_{i,t})`` — three observed emission channels:
  Bernoulli departure indicator, Gaussian displacement, Poisson
  communication count.
- ``u_{i,t} ∈ ℝ^F`` — exogenous input vector. The IO-HMM uses
  ``u = [vol, mand, ρ, r, v, τ]`` (six features); see
  [`fit_params.FEATURE_NAMES`](../src/iohmm_evac/inference/fit_params.py).

The IO-HMM parameters are
``θ = (θ^init, θ^trans, θ^emit)``:

- **Initial.** ``π_k = P(S_{i,0} = k)``, encoded as logits and softmax-normalized.
- **Transitions.** Multinomial logit:

  ```
  A_kj(u) = exp(α_kj + β_kj·u) / Σ_l exp(α_kl + β_kl·u)
  ```

  with the identifiability normalization ``(α_{k,k}, β_{k,k}) ≡ (0, 0)`` and
  ``α_{k,j} = -∞`` for forbidden transitions (encoded with the
  ``ALLOWED_TRANSITIONS`` mask).
- **Emissions** (the three channels are conditionally independent given the state):
  - ``D_{i,t} | S_{i,t}=k ~ Bernoulli(p_k)``
  - ``X_{i,t} | S_{i,t}=k ~ N(μ_k, σ_k²)``
  - ``C_{i,t} | S_{i,t}=k ~ Poisson(λ_k)``
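Because the channels are conditionally independent given the state, the per-state log emission density ``log b_k(y)`` is just the sum of the three channels' log densities. A minimal sketch (the function name and array shapes here are illustrative, not the project's API):

```python
import numpy as np
from scipy import stats

def log_emission(D, X, C, p, mu, sigma, lam):
    """Per-state log b_k(y): sum of the three channels' log densities.

    D, X, C have shape (N, T+1); p, mu, sigma, lam have shape (K,).
    Returns an (N, T+1, K) array of log b_k(y_{i,t}).
    """
    return (
        stats.bernoulli.logpmf(D[..., None], p)       # departure
        + stats.norm.logpdf(X[..., None], mu, sigma)  # displacement
        + stats.poisson.logpmf(C[..., None], lam)     # communication
    )
```

Broadcasting the trailing ``(K,)`` parameter axes against the ``(N, T+1, 1)`` observation axes gives all per-state densities in one shot.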

## Forward-backward (log-space)

All recursions live in log-space; the only place exponentials are taken is
inside `scipy.special.logsumexp`.

**Forward.**

```
log α_{i,0}(k) = log π_k + log b_k(y_{i,0})
log α_{i,t}(k) = log b_k(y_{i,t}) + logsumexp_j[log α_{i,t-1}(j) + log A_{j,k}(u_{i,t})]
```

**Backward.**

```
log β_{i,T}(k) = 0
log β_{i,t}(k) = logsumexp_j[log A_{k,j}(u_{i,t+1}) + log b_j(y_{i,t+1}) + log β_{i,t+1}(j)]
```

**Posteriors.**

```
log γ_{i,t}(k)   = log α_{i,t}(k) + log β_{i,t}(k) - L_i
log ξ_{i,t}(k,j) = log α_{i,t}(k) + log A_{k,j}(u_{i,t+1}) + log b_j(y_{i,t+1}) + log β_{i,t+1}(j) - L_i
```

with the per-household log-likelihood
``L_i = logsumexp_k log α_{i,T}(k)``.

The implementation precomputes ``log A_{k,j}(u_{i,t})`` as an
``(N, T+1, K, K)`` array; for the production data this is
``10000·121·5·5·8 B ≈ 240 MB``, which fits comfortably in RAM. Forbidden
cells carry ``LOG_EPS = -1e30`` (a finite stand-in for ``-∞``) so the
log-softmax is well-defined.
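Under those conventions the whole pass is a few lines of NumPy. The following single-household sketch (shapes and names are illustrative; the real code batches over households) returns both posteriors and the log-likelihood:

```python
import numpy as np
from scipy.special import logsumexp

def forward_backward(log_pi, log_A, log_b):
    """Log-space forward-backward for a single household.

    Illustrative shapes: log_pi (K,); log_A (T+1, K, K) with
    log_A[t, j, k] = log A_{j,k}(u_t) and log_A[0] unused; log_b (T+1, K).
    """
    T1, K = log_b.shape
    la = np.empty((T1, K))                    # log alpha
    lbeta = np.empty((T1, K))                 # log beta
    la[0] = log_pi + log_b[0]
    for t in range(1, T1):
        la[t] = log_b[t] + logsumexp(la[t - 1][:, None] + log_A[t], axis=0)
    lbeta[-1] = 0.0
    for t in range(T1 - 2, -1, -1):
        lbeta[t] = logsumexp(log_A[t + 1] + log_b[t + 1] + lbeta[t + 1],
                             axis=1)
    L = logsumexp(la[-1])                     # household log-likelihood L_i
    log_gamma = la + lbeta - L
    log_xi = (la[:-1, :, None] + log_A[1:]    # xi_t: transition t -> t+1
              + log_b[1:, None, :] + lbeta[1:, None, :] - L)
    return log_gamma, log_xi, L
```

A quick sanity check is that ``exp(log_gamma)`` rows sum to one and that ``log_xi`` marginalizes back to ``log_gamma`` along either state axis.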

## M-step

The M-step has a closed form for the initial distribution and the
emission parameters; the transition rows are fit by L-BFGS-B because the
multinomial logit's MLE has no closed form.

**Initial.**

```
π̂_k = (Σ_i γ_{i,0}(k)) / Σ_{i,l} γ_{i,0}(l)
```

**Emissions** (weighted MLEs).

```
p̂_k = Σ_{i,t} γ_{i,t}(k) D_{i,t} / Σ_{i,t} γ_{i,t}(k)
λ̂_k = Σ_{i,t} γ_{i,t}(k) C_{i,t} / Σ_{i,t} γ_{i,t}(k)
μ̂_k = Σ_{i,t} γ_{i,t}(k) X_{i,t} / Σ_{i,t} γ_{i,t}(k)
σ̂²_k = Σ_{i,t} γ_{i,t}(k) (X_{i,t} - μ̂_k)² / Σ_{i,t} γ_{i,t}(k)
```

``σ̂²_k`` is floored at ``sigma_floor² = (1e-2)²`` to prevent variance
collapse; ``p̂_k`` is clipped to ``[1e-6, 1 - 1e-6]`` and ``λ̂_k`` is floored
at ``1e-6`` for the same reason.
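All four updates share the same weighted-average pattern; a sketch with the floors and clips applied (names are illustrative):

```python
import numpy as np

def emission_mstep(gamma, D, X, C, sigma_floor=1e-2, eps=1e-6):
    """Weighted-MLE emission update with the floors described above.

    gamma: (N, T+1, K) posteriors; D, X, C: (N, T+1).
    Returns per-state (K,) arrays (p, mu, sigma2, lam).
    """
    W = gamma.sum(axis=(0, 1))                               # sum of gamma
    p = np.clip((gamma * D[..., None]).sum(axis=(0, 1)) / W, eps, 1 - eps)
    lam = np.maximum((gamma * C[..., None]).sum(axis=(0, 1)) / W, eps)
    mu = (gamma * X[..., None]).sum(axis=(0, 1)) / W
    sigma2 = (gamma * (X[..., None] - mu) ** 2).sum(axis=(0, 1)) / W
    return p, mu, np.maximum(sigma2, sigma_floor ** 2), lam
```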

**Transitions.** For each origin ``k``, the local objective is

```
Q_k(θ_k) = Σ_i Σ_{t=1..T} Σ_j ξ_{i,t-1}(k, j) · log A_{k,j}(u_{i,t}; θ_k)
```

and the gradient w.r.t. ``α_{k,j}`` (resp. ``β_{k,j,f}``) is

```
∂Q_k/∂α_{k,j}     = Σ_i Σ_{t=1..T} [ξ_{i,t-1}(k,j) - γ_{i,t-1}(k) · A_{k,j}(u_{i,t})]
∂Q_k/∂β_{k,j,f}   = Σ_i Σ_{t=1..T} [ξ_{i,t-1}(k,j) - γ_{i,t-1}(k) · A_{k,j}(u_{i,t})] · u_{i,t,f}
```

Both sums vectorize cleanly. ``-Q_k`` and ``-∇Q_k`` are passed to
`scipy.optimize.minimize(method="L-BFGS-B", jac=True)` separately for each
origin row. Forbidden destinations and the self-loop are *excluded from
the parameter vector entirely*: only the learnable cells participate in
the L-BFGS update. After optimization the row is reconstructed with the
fixed pins (``α_{k,k} = 0``, ``α_{k,j} = -∞`` for forbidden ``j``).
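One row's fit can be sketched as follows (function and argument names are hypothetical; the project's `fit_params.py` may organize this differently). It uses the identity ``γ(k) = Σ_j ξ(k, j)`` to avoid passing ``γ`` separately:

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import logsumexp

LOG_EPS = -1e30  # finite stand-in for -inf, as in the text above

def fit_transition_row(xi_k, U, k, allowed):
    """Sketch of one origin row's M-step.

    xi_k: (M, K) stacks xi(k, .) over all (household, step) pairs;
    U: (M, F) inputs; allowed: (K,) bool mask of reachable destinations.
    Free cells are the allowed destinations other than k itself.
    """
    M, K = xi_k.shape
    F = U.shape[1]
    free = allowed.copy()
    free[k] = False                               # self-loop pinned at 0
    J = int(free.sum())

    def neg_q(theta):
        alpha, beta = theta[:J], theta[J:].reshape(J, F)
        logits = np.full((M, K), LOG_EPS)         # forbidden cells
        logits[:, k] = 0.0                        # self-loop pin
        logits[:, free] = alpha + U @ beta.T
        logA = logits - logsumexp(logits, axis=1, keepdims=True)
        q = np.sum(xi_k * logA)
        gamma_k = xi_k.sum(axis=1, keepdims=True) # gamma(k) = sum_j xi(k,j)
        resid = (xi_k - gamma_k * np.exp(logA))[:, free]
        grad = np.concatenate([resid.sum(axis=0), (resid.T @ U).ravel()])
        return -q, -grad

    return minimize(neg_q, np.zeros(J * (1 + F)),
                    jac=True, method="L-BFGS-B")
```

Because the multinomial-logit log-likelihood is concave in ``θ_k``, the negated objective is convex and L-BFGS-B converges reliably from a zero start.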

## EM loop

```
Initialize θ
prev_ll = -∞
for iter in 1..max_iter:
    γ, ξ, ll = forward_backward(θ)
    if ll < prev_ll - 1e-6:                # bug guard (checked first, so a
        warn, restore prev θ, stop         # drop is never read as converged)
    if (ll - prev_ll) / max(|ll|, 1) < tol: break
    prev_ll = ll, last_good = θ
    θ.π        ← closed-form update
    for k:
        θ.{α_k, β_k} ← L-BFGS-B step
    θ.emit     ← closed-form weighted MLE
```

Defaults: ``max_iter = 200``, ``tol = 1e-5``. The implementation in
[`em.py`](../src/iohmm_evac/inference/em.py) keeps the previous
parameters around so that it can recover from a (theoretically
impossible) decrease in log-likelihood between iterations.
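The loop can be sketched as a small driver (the `e_step`/`m_step` callables stand in for the project's actual functions; names are hypothetical):

```python
import warnings
import numpy as np

def run_em(theta, e_step, m_step, max_iter=200, tol=1e-5):
    """EM driver with the decrease guard checked before the tolerance
    test, so a drop in log-likelihood is never mistaken for convergence."""
    prev_ll, last_good = -np.inf, theta
    for _ in range(max_iter):
        gamma, xi, ll = e_step(theta)
        if ll < prev_ll - 1e-6:                  # theoretically impossible
            warnings.warn("log-likelihood decreased; restoring last good")
            return last_good, prev_ll
        if (ll - prev_ll) / max(abs(ll), 1.0) < tol:
            return theta, ll                     # relative improvement stop
        prev_ll, last_good = ll, theta
        theta = m_step(theta, gamma, xi)
    return theta, prev_ll
```

A toy check: with ``e_step`` returning ``-(θ - 3)²`` as the log-likelihood and an ``m_step`` that halves the distance to 3, the driver stops near θ = 3.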

## Initialization strategies

[`initialization.py`](../src/iohmm_evac/inference/initialization.py)
exposes three:

- **`random`** — a generic prior (self-loops at logit 0; non-self learnable
  entries at moderately negative values plus jitter; emissions seeded from
  heuristic location/scales).
- **`kmeans`** — a mini K-means on ``(D, X, C)`` triples seeds the emission
  means; transitions and the initial distribution come from `random`.
- **`truth`** — initialize at the DGP's true parameters projected through
  ``dgp_truth_to_fit_init``. Used by recovery tests and by the
  ``--init truth`` CLI path.
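The `kmeans` strategy can be pictured as a plain Lloyd's loop; everything in this sketch (channel standardization, farthest-point seeding) is an illustrative assumption, not necessarily what `initialization.py` does:

```python
import numpy as np

def kmeans_emission_seed(D, X, C, K, iters=20, seed=0):
    """Mini K-means on (D, X, C) triples; cluster means in original
    units seed (p_k, mu_k, lam_k)."""
    Y = np.stack([D.ravel(), X.ravel(), C.ravel()], axis=1).astype(float)
    Z = (Y - Y.mean(0)) / (Y.std(0) + 1e-9)        # standardize channels
    rng = np.random.default_rng(seed)
    centers = [Z[rng.integers(len(Z))]]            # farthest-point init
    for _ in range(K - 1):
        d2 = ((Z[:, None] - np.asarray(centers)) ** 2).sum(-1).min(axis=1)
        centers.append(Z[np.argmax(d2)])
    centers = np.asarray(centers)
    for _ in range(iters):                         # Lloyd iterations
        labels = ((Z[:, None] - centers) ** 2).sum(-1).argmin(axis=1)
        for j in range(K):
            if np.any(labels == j):
                centers[j] = Z[labels == j].mean(0)
    # per-cluster means: D-mean -> p_k, X-mean -> mu_k, C-mean -> lam_k
    return np.array([Y[labels == j].mean(0) if np.any(labels == j)
                     else Y.mean(0) for j in range(K)])
```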

## Mis-specification on the production DGP

The production DGP in
[`iohmm_evac.dgp`](../src/iohmm_evac/dgp) uses three endogenous-feedback
features that the IO-HMM does not model:

- ``π_t`` (peer-departure share at the population level)
- ``c_t`` (network congestion)
- ``tir_{i,t}`` (per-household time-in-ER)

The IO-HMM cannot include these in ``u`` because they depend on the
latent state path that the inference is trying to recover. We therefore
*intentionally drop them* from the IO-HMM's input vector and live with
the resulting mis-specification. On the production simulation the fit
will not be exact; on the clean DGP in
[`tests/_clean_dgp.py`](../tests/_clean_dgp.py), which uses only
exogenous inputs, the fit recovers states (≥85% accuracy) and parameters
(β RMSE ≤ 0.5, μ RMSE ≤ 0.5).

This is a feature of the chapter narrative, not a bug. See
[`docs/diagnostics.md`](diagnostics.md) for what the recovery diagnostics
do with the imperfect fit.

## Future work (Build 3 and beyond)

- **Standard errors.** No observed-information or Louis's-identity
  uncertainty quantification yet — just point estimates. Bootstrap-based
  uncertainty is on the Build 3 roadmap.
- **Streaming / online inference.** Out of scope for the chapter.
- **Variational alternatives to EM.** Out of scope.
- **Inferring structural zeros from data.** We pin them at the DGP-known
  forbidden mask; learning them is a separate research question.