# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Tests for the Viterbi and posterior-mode decoders."""
from __future__ import annotations
import numpy as np
from iohmm_evac.diagnostics.decoding import posterior_mode, viterbi
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
EmissionFitParams,
FitParameters,
InitialFitParams,
TransitionFitParams,
)
from iohmm_evac.inference.forward_backward import forward_backward
def _deterministic_problem(seed: int = 0) -> tuple[FitParameters, FitData, np.ndarray]:
    """Build a three-state decoding fixture whose emissions nearly determine the state.

    The displacement means (0, 5, 10) sit fifty standard deviations apart
    (sigma = 0.1), so the observations alone identify the hidden state and a
    MAP decoder should reproduce the hand-written state path exactly.

    Args:
        seed: Seed for the ``numpy`` generator that samples the observations.

    Returns:
        ``(params, data, truth)``, where ``truth`` is the ``(5, 12)`` int64
        state path the observations in ``data`` were sampled from.
    """
    generator = np.random.default_rng(seed)
    n_states, n_features = 3, 1
    n_sequences, n_steps = 5, 12

    # Transition logits. Assuming [from, to] indexing (confirm against
    # TransitionFitParams), backward moves get the most negative logits.
    alpha = np.array(
        [
            [0.0, -0.5, -3.0],
            [-3.0, 0.0, -0.5],
            [-3.0, -3.0, 0.0],
        ],
        dtype=np.float64,
    )
    # The single input feature only affects the (0, 1) and (1, 2) logits.
    beta = np.zeros((n_states, n_states, n_features))
    beta[0, 1, 0] = 1.0
    beta[1, 2, 0] = 1.0

    emissions = EmissionFitParams(
        p_departure=np.array([0.05, 0.5, 0.95]),
        mu_displacement=np.array([0.0, 5.0, 10.0]),
        sigma_displacement=np.full(n_states, 0.1),
        lambda_comm=np.array([0.5, 1.5, 3.0]),
    )
    params = FitParameters(
        initial=InitialFitParams(logits=np.array([0.0, -3.0, -3.0])),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emissions,
    )

    # Every sequence sees the same input ramp from 0 to 1.
    inputs = np.zeros((n_sequences, n_steps, n_features), dtype=np.float64)
    inputs[..., 0] = np.linspace(0.0, 1.0, n_steps)

    # Hand-written state paths that only ever step 0 -> 1 -> 2.
    truth = np.array(
        [
            [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
            [0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2],
            [0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2],
            [0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2],
        ],
        dtype=np.int64,
    )

    # Draw order (uniform, normal, poisson) fixes the sampled values per seed.
    uniforms = generator.random((n_sequences, n_steps))
    departure = (uniforms < emissions.p_departure[truth]).astype(np.float64)
    displacement = generator.normal(
        emissions.mu_displacement[truth],
        emissions.sigma_displacement[truth],
    )
    comm = generator.poisson(emissions.lambda_comm[truth]).astype(np.float64)

    data = FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=truth,
    )
    return params, data, truth
def test_viterbi_returns_correct_shape() -> None:
    """Viterbi returns one int64 state label per sequence and time step."""
    params, data, _truth = _deterministic_problem()
    decoded = viterbi(params, data)
    expected_shape = (data.n, data.t_total + 1)
    assert decoded.shape == expected_shape
    assert decoded.dtype == np.int64
def test_viterbi_recovers_truth_under_low_noise() -> None:
    """Viterbi decoding reproduces the sampled state path when emissions are sharp."""
    params, data, truth = _deterministic_problem(seed=0)
    matches = viterbi(params, data) == truth
    # The fixture has 5 x 12 = 60 labels, so a 0.99 threshold allows no mismatches.
    accuracy = float(matches.mean())
    assert accuracy >= 0.99, f"Viterbi accuracy {accuracy:.4f} too low"
def test_posterior_mode_shape_and_consistency() -> None:
    """Posterior-mode decoding yields valid states that match the sampled path.

    The fixture's emissions nearly determine the hidden state, so the
    per-step posterior argmax should recover ``truth`` just as Viterbi does.
    Previously this test ignored ``truth`` and checked only the index range.
    """
    params, data, truth = _deterministic_problem(seed=1)
    fb = forward_backward(params, data)
    mode = posterior_mode(fb.log_gamma)
    n_states = params.emissions.p_departure.shape[0]

    assert mode.shape == (data.n, data.t_total + 1)
    # Mode is per-step argmax: it must be a valid state index.
    assert mode.min() >= 0
    assert mode.max() < n_states

    # Consistency: with near-deterministic emissions, the marginal MAP state
    # at each step should coincide with the true state. Over 60 labels a 0.99
    # threshold allows no mismatches.
    accuracy = float(np.mean(mode == truth))
    assert accuracy >= 0.99, f"posterior-mode accuracy {accuracy:.4f} too low"