# src/iohmm_evac/inference/forward_backward.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Log-space forward-backward recursions for the IO-HMM.

All quantities are kept in log-space; ``logsumexp`` is the only place where
exponentials are taken explicitly.
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.inference.log_space import LOG_EPS, logsumexp, safe_log
from iohmm_evac.types import FloatArray

__all__ = [
    "ForwardBackwardResult",
    "emission_log_prob",
    "forward_backward",
    "log_transition_matrix",
]


@dataclass(frozen=True, slots=True)
class ForwardBackwardResult:
    """Output of a single forward-backward pass.

    All log-arrays are well-defined (no ``NaN``); forbidden cells in
    ``log_xi`` are at :data:`~iohmm_evac.inference.log_space.LOG_EPS`.
    """

    log_gamma: FloatArray
    """Posterior over states, shape ``(N, T+1, K)``."""
    log_xi: FloatArray
    """Pairwise posterior, shape ``(N, T, K, K)`` for transitions ``t -> t+1``."""
    log_alpha: FloatArray
    """Forward messages, shape ``(N, T+1, K)``."""
    log_beta: FloatArray
    """Backward messages, shape ``(N, T+1, K)``."""
    log_likelihood: FloatArray
    """Per-household log-likelihood, shape ``(N,)``."""


def log_transition_matrix(
    inputs: FloatArray,
    transitions_alpha: FloatArray,
    transitions_beta: FloatArray,
) -> FloatArray:
    """Return ``log A_{kj}(u_{i,t})`` for every ``(i, t, k, j)``.

    The output has shape ``(N, T+1, K, K)``. Forbidden destinations carry
    :data:`~iohmm_evac.inference.log_space.LOG_EPS` (a finite stand-in for
    ``-inf``); each row over ``j`` is log-softmax-normalized.

    Note: although we materialize an ``(N, T+1, K, K)`` array, only slots
    ``t = 1..T`` participate in the recursion (transitions ``t-1 -> t``);
    ``t = 0`` is never read.
    """
    # Affine logits: per-pair intercept plus an input-dependent term.
    raw = np.einsum("ntf,kjf->ntkj", inputs, transitions_beta)
    raw = raw + transitions_alpha[None, None, :, :]
    # Non-finite intercepts mark structurally forbidden transitions.
    blocked = ~np.isfinite(transitions_alpha)
    if blocked.any():
        raw = np.where(blocked[None, None, :, :], LOG_EPS, raw)
    # Stable log-softmax over the destination axis j.
    peak = np.max(raw, axis=-1, keepdims=True)
    log_norm = peak + np.log(np.exp(raw - peak).sum(axis=-1, keepdims=True))
    return np.asarray(raw - log_norm, dtype=np.float64)


def _gaussian_log_pdf(x: FloatArray, mu: float, sigma: float) -> FloatArray:
    """Vectorized univariate Gaussian log-pdf with a finite ``sigma`` floor."""
    var = max(sigma * sigma, 1e-12)
    diff = x - mu
    out = -0.5 * (np.log(2.0 * np.pi * var) + diff * diff / var)
    return np.asarray(out, dtype=np.float64)


def emission_log_prob(data: FitData, params: FitParameters) -> FloatArray:
    """Compute ``log b_k(y_{i,t})`` for every (i, t, k).

    Channels are assumed conditionally independent given the state, so the
    log-pdf is the sum over the three observed channels:

    * ``departure``    -- Bernoulli(p_k), probability clipped to [1e-12, 1 - 1e-12];
    * ``displacement`` -- Gaussian(mu_k, sigma_k) with a floored variance;
    * ``comm``         -- Poisson(lambda_k), rate floored at 1e-9.

    Returns an array of shape ``(N, T+1, K)``.
    """
    # Hoisted out of the per-state loop: the import ran once per k before.
    from scipy.special import gammaln

    p = params.emissions.p_departure
    mu = params.emissions.mu_displacement
    sigma = params.emissions.sigma_displacement
    lam = params.emissions.lambda_comm

    n, t_plus_1 = data.departure.shape
    k_states = int(p.shape[0])

    # log(c!) does not depend on the state; compute it once instead of K times.
    log_comm_factorial = np.asarray(gammaln(data.comm + 1.0), dtype=np.float64)

    # np.empty is safe: every (:, :, k) slice is written below.
    log_b = np.empty((n, t_plus_1, k_states), dtype=np.float64)
    for k in range(k_states):
        # Bernoulli departure channel.
        log_p = float(np.log(np.clip(p[k], 1e-12, 1 - 1e-12)))
        log_1mp = float(np.log(np.clip(1.0 - p[k], 1e-12, 1 - 1e-12)))
        bern = data.departure * log_p + (1.0 - data.departure) * log_1mp

        # Gaussian displacement channel (variance floored inside the helper).
        gauss = _gaussian_log_pdf(data.displacement, float(mu[k]), float(sigma[k]))

        # Poisson communication channel with a floored rate.
        lam_k = float(max(lam[k], 1e-9))
        pois = data.comm * np.log(lam_k) - lam_k - log_comm_factorial

        log_b[:, :, k] = bern + gauss + pois
    return log_b


def _logsumexp_axis(x: FloatArray, axis: int) -> FloatArray:
    """Hand-rolled log-sum-exp along one axis; faster than scipy in tight loops."""
    m = np.max(x, axis=axis, keepdims=True)
    out = np.log(np.exp(x - m).sum(axis=axis)) + np.squeeze(m, axis=axis)
    return np.asarray(out, dtype=np.float64)


def _forward_pass(
    log_b: FloatArray,
    log_initial: FloatArray,
    log_a: FloatArray,
) -> tuple[FloatArray, FloatArray]:
    """Run the forward recursion.

    Returns the log-alpha messages, shape ``(N, T+1, K)``, and the
    per-household log-likelihood, shape ``(N,)``.
    """
    n, t_plus_1, k = log_b.shape
    log_alpha = np.empty((n, t_plus_1, k), dtype=np.float64)
    # Base case: prior over states times the first emission.
    log_alpha[:, 0, :] = log_initial[None, :] + log_b[:, 0, :]
    for step in range(1, t_plus_1):
        # alpha_t(k) = b_t(k) * sum_j alpha_{t-1}(j) A_{jk}(u_t), in log space.
        scores = log_alpha[:, step - 1, :, None] + log_a[:, step, :, :]
        log_alpha[:, step, :] = log_b[:, step, :] + _logsumexp_axis(scores, axis=1)
    # Evidence: marginalize the final forward slice over states.
    log_likelihood = logsumexp(log_alpha[:, t_plus_1 - 1, :], axis=1)
    return log_alpha, np.asarray(log_likelihood, dtype=np.float64)


def _backward_pass(log_b: FloatArray, log_a: FloatArray) -> FloatArray:
    """Run the backward recursion; return log-beta messages, shape ``(N, T+1, K)``.

    The final slot is log(1) = 0, supplied by the zero initialization.
    """
    n, t_plus_1, k = log_b.shape
    log_beta = np.zeros((n, t_plus_1, k), dtype=np.float64)
    for step in reversed(range(t_plus_1 - 1)):
        # beta_t(k) = sum_j A_{kj}(u_{t+1}) b_{t+1}(j) beta_{t+1}(j), in log space.
        future = log_b[:, step + 1, :] + log_beta[:, step + 1, :]  # (N, K)
        scores = log_a[:, step + 1, :, :] + future[:, None, :]
        log_beta[:, step, :] = _logsumexp_axis(scores, axis=2)
    return log_beta


def forward_backward(params: FitParameters, data: FitData) -> ForwardBackwardResult:
    """Run one forward-backward sweep and package posteriors plus log-likelihood."""
    log_a = log_transition_matrix(data.inputs, params.transitions.alpha, params.transitions.beta)
    log_b = emission_log_prob(data, params)
    log_initial = safe_log(params.initial.probs())

    log_alpha, log_likelihood = _forward_pass(log_b, log_initial, log_a)
    log_beta = _backward_pass(log_b, log_a)

    # gamma_t(k) = alpha_t(k) beta_t(k) / evidence, all in log space.
    log_gamma = log_alpha + log_beta - log_likelihood[:, None, None]

    n, t_plus_1, k = log_b.shape
    t_total = t_plus_1 - 1
    if t_total > 0:
        # xi_t(k, j) = alpha_t(k) A_{kj}(u_{t+1}) b_{t+1}(j) beta_{t+1}(j) / evidence;
        # slices shift by one step so slot t pairs with slot t+1.
        log_xi = (
            log_alpha[:, :-1, :, None]
            + log_a[:, 1:, :, :]
            + log_b[:, 1:, None, :]
            + log_beta[:, 1:, None, :]
            - log_likelihood[:, None, None, None]
        )
    else:
        # No transitions observed: keep an empty, correctly-shaped array.
        log_xi = np.zeros((n, 0, k, k), dtype=np.float64)

    return ForwardBackwardResult(
        log_gamma=np.asarray(log_gamma, dtype=np.float64),
        log_xi=np.asarray(log_xi, dtype=np.float64),
        log_alpha=log_alpha,
        log_beta=log_beta,
        log_likelihood=log_likelihood,
    )