# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Numerical-correctness tests for the log-space forward-backward."""
from __future__ import annotations
import itertools
import numpy as np
import pytest
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
EmissionFitParams,
FitParameters,
InitialFitParams,
TransitionFitParams,
)
from iohmm_evac.inference.forward_backward import (
emission_log_prob,
forward_backward,
log_transition_matrix,
)
def _make_tiny_problem() -> tuple[FitParameters, FitData]:
    """Build a K=2, T=2 (so T+1=3 timesteps), N=1 problem for brute-force checking.

    Two hidden states and a two-dimensional input feature vector. The emissions
    (Bernoulli departure, Gaussian displacement, Poisson communication count)
    use deliberately distinct per-state parameters so the posteriors are not
    trivial.

    Returns:
        ``(params, data)``. Here ``data`` holds a single household with three
        timesteps and no ground-truth states.
    """
    # Transition logits: alpha is (K, K) intercepts, beta is (K, K, F) input slopes.
    alpha = np.array([[0.0, -0.5], [0.3, 0.0]], dtype=np.float64)
    beta = np.array(
        [
            [[0.0, 0.0], [0.4, -0.1]],
            [[-0.2, 0.3], [0.0, 0.0]],
        ],
        dtype=np.float64,
    )
    initial_logits = np.array([0.4, -0.4], dtype=np.float64)
    emit = EmissionFitParams(
        p_departure=np.array([0.2, 0.8], dtype=np.float64),
        mu_displacement=np.array([1.0, 4.0], dtype=np.float64),
        sigma_displacement=np.array([1.0, 1.5], dtype=np.float64),
        lambda_comm=np.array([1.0, 3.0], dtype=np.float64),
    )
    params = FitParameters(
        initial=InitialFitParams(logits=initial_logits),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
        feature_names=("x", "y"),
    )
    # Shapes: inputs is (N=1, T+1=3, F=2); each observation stream is (N=1, T+1=3).
    inputs = np.array([[[0.5, 0.0], [-0.4, 0.7], [0.1, -0.2]]], dtype=np.float64)
    departure = np.array([[1.0, 0.0, 1.0]], dtype=np.float64)
    displacement = np.array([[1.2, 3.6, 2.5]], dtype=np.float64)
    comm = np.array([[1.0, 3.0, 2.0]], dtype=np.float64)
    data = FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=None,
    )
    # Kept tiny on purpose: exhaustive enumeration visits K ** (T+1) = 2 ** 3 = 8
    # state paths, which is cheap enough to check forward-backward exactly.
    return params, data
def _brute_force_log_likelihood(
    params: FitParameters, data: FitData, k: int | None = None
) -> float:
    """Return log p(observations) by summing the joint likelihood of every state path.

    The sum is exhaustive over all ``k ** (T+1)`` hidden-state sequences, so use
    this only on tiny problems.

    Args:
        params: Model parameters to evaluate.
        data: Observations. They must contain exactly one household (N == 1).
        k: Number of hidden states. Defaults to the state dimension of the
            emission log-probability array. If given, it must match that
            dimension.

    Returns:
        The log marginal likelihood of the single household's observations.
    """
    log_a = log_transition_matrix(data.inputs, params.transitions.alpha, params.transitions.beta)
    log_b = emission_log_prob(data, params)
    log_initial = np.log(params.initial.probs())
    n, t_plus_1, n_states = log_b.shape
    assert n == 1, f"brute force supports a single household only, got N={n}"
    if k is None:
        k = n_states
    # A mismatched k would silently enumerate too few paths (k < K) or index
    # past the state axis (k > K); fail loudly instead.
    assert k == n_states, f"k={k} does not match emission state dimension K={n_states}"
    total = -np.inf
    for path in itertools.product(range(k), repeat=t_plus_1):
        ll = log_initial[path[0]] + log_b[0, 0, path[0]]
        for t in range(1, t_plus_1):
            # The transition from step t-1 into step t is read from log_a[:, t].
            ll += log_a[0, t, path[t - 1], path[t]] + log_b[0, t, path[t]]
        total = float(np.logaddexp(total, ll))
    return total
def test_brute_force_equivalence() -> None:
    """Forward-backward log-likelihood must equal exhaustive path enumeration."""
    tiny_params, tiny_data = _make_tiny_problem()
    result = forward_backward(tiny_params, tiny_data)
    expected = _brute_force_log_likelihood(tiny_params, tiny_data, k=2)
    observed = float(result.log_likelihood[0])
    assert observed == pytest.approx(expected, abs=1e-10)
def test_posteriors_sum_to_one() -> None:
    """Each timestep's state posterior gamma must be a normalised distribution."""
    tiny_params, tiny_data = _make_tiny_problem()
    posterior = np.exp(forward_backward(tiny_params, tiny_data).log_gamma)
    per_step_mass = posterior.sum(axis=2)
    assert np.allclose(per_step_mass, 1.0, atol=1e-9)
def test_xi_marginalizes_to_gamma() -> None:
    """Pairwise posteriors xi must agree with the single-step posteriors gamma.

    ``xi[n, t, i, j]`` is the posterior of state ``i`` at step ``t`` and state
    ``j`` at step ``t+1``. Summing out ``j`` therefore recovers ``gamma[:, t]``,
    and summing out ``i`` recovers ``gamma[:, t+1]``.
    """
    params, data = _make_tiny_problem()
    fb = forward_backward(params, data)
    xi = np.exp(fb.log_xi)
    gamma = np.exp(fb.log_gamma)
    sum_over_j = xi.sum(axis=3)  # (N, T, K): marginal of the "from" state
    assert np.allclose(sum_over_j, gamma[:, :-1, :], atol=1e-9)
    sum_over_i = xi.sum(axis=2)  # (N, T, K): marginal of the "to" state
    assert np.allclose(sum_over_i, gamma[:, 1:, :], atol=1e-9)
def test_fb_handles_multi_household() -> None:
    """Forward-backward should run correctly on a many-household, longer problem.

    The transition intercepts contain ``-inf`` entries (forbidden transitions)
    and an absorbing last state. Beyond output shapes, the test checks:

    * per-household log-likelihoods are finite;
    * gamma is normalised;
    * xi is consistent with gamma;
    * forbidden transitions receive no pairwise posterior mass.
    """
    rng = np.random.default_rng(0)
    n, t_plus_1, k, f_dim = 8, 12, 5, 6
    # Draw order matters: it fixes the data produced by the seeded generator.
    inputs = rng.normal(size=(n, t_plus_1, f_dim)) * 0.3
    departure = (rng.random((n, t_plus_1)) < 0.4).astype(np.float64)
    displacement = rng.normal(size=(n, t_plus_1)) * 0.5
    comm = rng.poisson(1.0, size=(n, t_plus_1)).astype(np.float64)
    data = FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=None,
    )
    ninf = -np.inf
    # -inf marks a transition that must never be taken; state 4 is absorbing.
    alpha = np.array(
        [
            [0.0, -1.0, ninf, ninf, ninf],
            [-2.0, 0.0, -2.5, ninf, ninf],
            [ninf, ninf, 0.0, -2.0, -3.0],
            [ninf, ninf, ninf, 0.0, -2.0],
            [ninf, ninf, ninf, ninf, 0.0],
        ],
        dtype=np.float64,
    )
    allowed = np.isfinite(alpha)
    beta = np.zeros((k, k, f_dim), dtype=np.float64)
    beta[0, 1] = rng.normal(size=f_dim) * 0.2
    beta[1, 2] = rng.normal(size=f_dim) * 0.2
    beta[2, 3] = rng.normal(size=f_dim) * 0.2
    beta[3, 4] = rng.normal(size=f_dim) * 0.2
    emit = EmissionFitParams(
        p_departure=np.array([0.05, 0.05, 0.1, 0.95, 0.05], dtype=np.float64),
        mu_displacement=np.array([0.0, 0.0, 0.0, 1.0, 5.0], dtype=np.float64),
        sigma_displacement=np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float64),
        lambda_comm=np.array([0.5, 1.0, 2.0, 1.5, 0.5], dtype=np.float64),
    )
    params = FitParameters(
        initial=InitialFitParams(
            logits=np.array([0.0, -1.0, -1.0, -1.0, -1.0], dtype=np.float64)
        ),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
    )
    fb = forward_backward(params, data)
    assert fb.log_gamma.shape == (n, t_plus_1, k)
    assert fb.log_xi.shape == (n, t_plus_1 - 1, k, k)
    # Every emission has full support here, so each household's likelihood is positive.
    assert np.shape(fb.log_likelihood) == (n,)
    assert np.all(np.isfinite(fb.log_likelihood))
    gamma = np.exp(fb.log_gamma)
    assert np.allclose(gamma.sum(axis=2), 1.0, atol=1e-9)
    xi = np.exp(fb.log_xi)
    assert np.allclose(xi.sum(axis=3), gamma[:, :-1, :], atol=1e-9)
    # Boolean mask over the last two (from, to) axes selects forbidden transitions.
    assert np.allclose(xi[:, :, ~allowed], 0.0, atol=1e-12)