src/iohmm_evac/diagnostics/decoding.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Posterior decoding: Viterbi MAP path and per-step posterior mode."""

from __future__ import annotations

import numpy as np

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.inference.forward_backward import (
    emission_log_prob,
    log_transition_matrix,
)
from iohmm_evac.inference.log_space import safe_log
from iohmm_evac.types import FloatArray, IntArray

__all__ = ["posterior_mode", "viterbi"]


def posterior_mode(log_gamma: FloatArray) -> IntArray:
    """Pick the individually most likely state at every time step.

    This is the marginal MAP decoding: each step is maximised on its own,
    so the resulting sequence need not be a jointly feasible path (contrast
    with :func:`viterbi`).

    Args:
        log_gamma: Log posterior state marginals; states are read along
            axis 2 (presumably ``(N, T+1, K)`` — confirm against caller).

    Returns:
        int64 array of winning state indices, with axis 2 removed. Ties
        resolve to the lowest state index (``np.argmax`` semantics).
    """
    state_axis = 2
    winners = np.argmax(log_gamma, axis=state_axis)
    return np.asarray(winners, dtype=np.int64)


def viterbi(params: FitParameters, data: FitData) -> IntArray:
    """Decode the jointly most probable hidden-state path for every sequence.

    Runs the max-product (Viterbi) recursion entirely in log-space, storing
    a back-pointer per (sequence, step, state). It then backtracks from the
    best terminal state.

    Args:
        params: Fitted model; supplies transition ``alpha``/``beta``, the
            initial-state distribution, and whatever ``emission_log_prob``
            reads.
        data: Observed inputs/outputs for ``N`` sequences of ``T+1`` steps.

    Returns:
        An ``(N, T+1)`` int64 array of state indices. Ties at any max are
        broken toward the lowest state index (``argmax`` semantics).
    """
    log_trans = log_transition_matrix(
        data.inputs, params.transitions.alpha, params.transitions.beta
    )
    log_emit = emission_log_prob(data, params)
    log_pi = safe_log(params.initial.probs())

    n_seq, n_steps, n_states = log_emit.shape
    # score[s, t, k]: best log-joint of any path ending in state k at step t.
    score = np.empty((n_seq, n_steps, n_states), dtype=np.float64)
    # backptr[s, t, k]: predecessor state of k on that best path. Step 0 has
    # no predecessor, so it stays at the placeholder 0.
    backptr = np.zeros((n_seq, n_steps, n_states), dtype=np.int64)

    score[:, 0, :] = log_emit[:, 0, :] + log_pi[None, :]
    for step in range(1, n_steps):
        # via[s, j, k] = score[s, step-1, j] + log_trans[s, step, j, k]:
        # the score of reaching state k at `step` through predecessor j.
        via = score[:, step - 1, :, np.newaxis] + log_trans[:, step]
        backptr[:, step, :] = via.argmax(axis=1)
        score[:, step, :] = via.max(axis=1) + log_emit[:, step, :]

    rows = np.arange(n_seq)
    path = np.empty((n_seq, n_steps), dtype=np.int64)
    path[:, -1] = score[:, -1, :].argmax(axis=1)
    # Follow back-pointers from the terminal step down to step 0.
    for step in reversed(range(n_steps - 1)):
        path[:, step] = backptr[rows, step + 1, path[:, step + 1]]
    return path