# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Posterior decoding: Viterbi MAP path and per-step posterior mode."""
from __future__ import annotations
import numpy as np
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.inference.forward_backward import (
emission_log_prob,
log_transition_matrix,
)
from iohmm_evac.inference.log_space import safe_log
from iohmm_evac.types import FloatArray, IntArray
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["posterior_mode", "viterbi"]
def posterior_mode(log_gamma: FloatArray) -> IntArray:
    """Per-step argmax of the posterior (a.k.a. marginal MAP path).

    Each step is decoded independently by picking the state with the largest
    log posterior. The resulting sequence therefore need not be the jointly
    most probable path (that is what :func:`viterbi` computes).

    Parameters
    ----------
    log_gamma:
        Log posterior state marginals with the state dimension on axis 2.
        Presumably an ``(N, T+1, K)`` array from forward-backward; confirm
        against the caller.

    Returns
    -------
    IntArray
        An int64 array equal to ``log_gamma`` with axis 2 reduced away.
        Ties resolve to the lowest state index (``np.argmax`` semantics).
    """
    best_state = np.argmax(log_gamma, axis=2)
    return best_state.astype(np.int64, copy=False)
def viterbi(params: FitParameters, data: FitData) -> IntArray:
    """Joint MAP state path via the Viterbi recursion in log-space.

    Parameters
    ----------
    params:
        Fitted parameters. Supplies the initial-state probabilities
        (``params.initial.probs()``), the transition coefficients
        (``params.transitions.alpha`` / ``beta``, combined with
        ``data.inputs``), and whatever ``emission_log_prob`` reads.
    data:
        Observed data. ``data.inputs`` drives the transition matrices and the
        whole object is passed to ``emission_log_prob``.

    Returns
    -------
    IntArray
        An ``(N, T+1)`` int64 array holding, for every sequence, the single
        most probable hidden-state path. Ties resolve to the lowest state
        index (``np.argmax`` semantics).
    """
    log_a = log_transition_matrix(data.inputs, params.transitions.alpha, params.transitions.beta)
    log_b = emission_log_prob(data, params)
    log_initial = safe_log(params.initial.probs())
    n, t_plus_1, k = log_b.shape
    rows = np.arange(n)

    # Back-pointers: psi[:, t, s] is the best predecessor of state s at step t.
    # Row 0 has no predecessor; it stays zero and is never read when
    # backtracking.
    psi = np.zeros((n, t_plus_1, k), dtype=np.int64)

    # The recursion only ever needs the previous step's scores, so keep a
    # rolling (N, K) accumulator instead of the full (N, T+1, K) delta table.
    # Accumulate in float64 regardless of the input dtypes.
    score = np.asarray(log_initial[None, :] + log_b[:, 0, :], dtype=np.float64)
    for t in range(1, t_plus_1):
        # candidate[i, j_prev, k_curr] = score[i, j_prev] + log_a[i, t, j_prev, k_curr]
        candidate = score[:, :, None] + log_a[:, t, :, :]
        best_prev = np.argmax(candidate, axis=1)
        psi[:, t, :] = best_prev
        # Read the maximum back through the argmax rather than doing a second
        # reduction over the (N, K, K) candidate tensor.
        best = np.take_along_axis(candidate, best_prev[:, None, :], axis=1)[:, 0, :]
        score = log_b[:, t, :] + best

    # Terminal state = best final score; then follow back-pointers to t = 0.
    path = np.empty((n, t_plus_1), dtype=np.int64)
    path[:, t_plus_1 - 1] = np.argmax(score, axis=1)
    for t in range(t_plus_1 - 2, -1, -1):
        path[:, t] = psi[rows, t + 1, path[:, t + 1]]
    return path