tests/test_forward_backward.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Numerical-correctness tests for the log-space forward-backward."""

from __future__ import annotations

import itertools

import numpy as np
import pytest

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
    EmissionFitParams,
    FitParameters,
    InitialFitParams,
    TransitionFitParams,
)
from iohmm_evac.inference.forward_backward import (
    emission_log_prob,
    forward_backward,
    log_transition_matrix,
)


def _make_tiny_problem() -> tuple[FitParameters, FitData]:
    """Build a K=2, T=2 (so T+1=3 timesteps), N=1 problem for brute-force checking.

    Two states, two input features. Emissions: Bernoulli + Gaussian +
    Poisson with deliberately distinct per-state parameters so the
    posteriors aren't trivial.

    Returns:
        ``(params, data)`` small enough that every one of the ``2 ** 3``
        state paths can be enumerated explicitly.
    """
    alpha = np.array([[0.0, -0.5], [0.3, 0.0]], dtype=np.float64)
    beta = np.array(
        [
            [[0.0, 0.0], [0.4, -0.1]],
            [[-0.2, 0.3], [0.0, 0.0]],
        ],
        dtype=np.float64,
    )
    initial_logits = np.array([0.4, -0.4], dtype=np.float64)
    emit = EmissionFitParams(
        p_departure=np.array([0.2, 0.8], dtype=np.float64),
        mu_displacement=np.array([1.0, 4.0], dtype=np.float64),
        sigma_displacement=np.array([1.0, 1.5], dtype=np.float64),
        lambda_comm=np.array([1.0, 3.0], dtype=np.float64),
    )
    params = FitParameters(
        initial=InitialFitParams(logits=initial_logits),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
        feature_names=("x", "y"),
    )

    inputs = np.array([[[0.5, 0.0], [-0.4, 0.7], [0.1, -0.2]]], dtype=np.float64)
    departure = np.array([[1.0, 0.0, 1.0]], dtype=np.float64)
    displacement = np.array([[1.2, 3.6, 2.5]], dtype=np.float64)
    comm = np.array([[1.0, 3.0, 2.0]], dtype=np.float64)
    data = FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=None,
    )
    # Shapes: inputs is (N=1, T+1=3, F=2); each observation array is
    # (N=1, T+1=3). The state count K=2 comes from the parameter arrays above.
    return params, data


def _brute_force_log_likelihood(
    params: FitParameters, data: FitData, k: int | None = None
) -> float:
    """Sum joint likelihoods over every state path for the tiny problem.

    Enumerates all ``k ** (T + 1)`` hidden-state sequences, so it is only
    usable on very small problems. It reuses ``log_transition_matrix`` and
    ``emission_log_prob``, so it checks the forward-backward recursion rather
    than those building blocks.

    Args:
        params: Model parameters.
        data: Observations for exactly one household (N == 1).
        k: Number of hidden states. Defaults to the state count of the
            emission log-probabilities; if given, it must equal that count.

    Returns:
        The log-likelihood of household 0.
    """
    log_a = log_transition_matrix(data.inputs, params.transitions.alpha, params.transitions.beta)
    log_b = emission_log_prob(data, params)
    log_initial = np.log(params.initial.probs())

    n, t_plus_1, n_states = log_b.shape
    assert n == 1, f"brute force expects a single household, got N={n}"
    if k is None:
        k = n_states
    # A smaller ``k`` would silently skip every path through the missing
    # states (under-estimating the likelihood); a larger one would index
    # past the end of the state axis.
    assert k == n_states, f"k={k} does not match the model's K={n_states}"

    total = -np.inf
    for path in itertools.product(range(k), repeat=t_plus_1):
        ll = log_initial[path[0]] + log_b[0, 0, path[0]]
        for t in range(1, t_plus_1):
            ll += log_a[0, t, path[t - 1], path[t]] + log_b[0, t, path[t]]
        total = float(np.logaddexp(total, ll))
    return total


def test_brute_force_equivalence() -> None:
    """Forward-backward log-likelihood must match exhaustive path enumeration."""
    tiny_params, tiny_data = _make_tiny_problem()
    result = forward_backward(tiny_params, tiny_data)
    expected = _brute_force_log_likelihood(tiny_params, tiny_data, k=2)
    observed = float(result.log_likelihood[0])
    assert observed == pytest.approx(expected, abs=1e-10)


def test_posteriors_sum_to_one() -> None:
    """Each timestep's state posterior must be a proper distribution."""
    tiny_params, tiny_data = _make_tiny_problem()
    log_gamma = forward_backward(tiny_params, tiny_data).log_gamma
    mass_per_step = np.exp(log_gamma).sum(axis=2)
    assert np.allclose(mass_per_step, 1.0, atol=1e-9)


def test_xi_marginalizes_to_gamma() -> None:
    """Pairwise posteriors must marginalize to the single-step posteriors.

    With ``xi[n, t, i, j]`` the joint posterior of state ``i`` at ``t`` and
    ``j`` at ``t + 1``, summing out ``j`` gives ``gamma[:, t]`` and summing
    out ``i`` gives ``gamma[:, t + 1]``. Both directions are checked so the
    ``t + 1`` end of each pair is pinned too.
    """
    params, data = _make_tiny_problem()
    fb = forward_backward(params, data)
    xi = np.exp(fb.log_xi)  # (N, T, K, K)
    gamma = np.exp(fb.log_gamma)  # (N, T+1, K)
    sum_over_j = xi.sum(axis=3)  # (N, T, K): marginal at t
    sum_over_i = xi.sum(axis=2)  # (N, T, K): marginal at t + 1
    assert np.allclose(sum_over_j, gamma[:, :-1, :], atol=1e-9)
    assert np.allclose(sum_over_i, gamma[:, 1:, :], atol=1e-9)


def test_fb_handles_multi_household() -> None:
    """Forward-backward should run on a many-household, longer problem too.

    Uses K=5 states with many forbidden (``-inf``) transition logits, to
    exercise the log-space recursion when some transition log-probabilities
    are ``-inf``.
    """
    rng = np.random.default_rng(0)
    n, t_plus_1, k, f_dim = 8, 12, 5, 6
    inputs = rng.normal(size=(n, t_plus_1, f_dim)) * 0.3
    departure = (rng.random((n, t_plus_1)) < 0.4).astype(np.float64)
    displacement = rng.normal(size=(n, t_plus_1)) * 0.5
    comm = rng.poisson(1.0, size=(n, t_plus_1)).astype(np.float64)
    data = FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=None,
    )

    ninf = -np.inf
    # ``-inf`` entries forbid a transition outright. Every diagonal entry is
    # finite, so remaining in the current state is always possible.
    alpha = np.array(
        [
            [0.0, -1.0, ninf, ninf, ninf],
            [-2.0, 0.0, -2.5, ninf, ninf],
            [ninf, ninf, 0.0, -2.0, -3.0],
            [ninf, ninf, ninf, 0.0, -2.0],
            [ninf, ninf, ninf, ninf, 0.0],
        ],
        dtype=np.float64,
    )
    beta = np.zeros((k, k, f_dim), dtype=np.float64)
    beta[0, 1] = rng.normal(size=f_dim) * 0.2
    beta[1, 2] = rng.normal(size=f_dim) * 0.2
    beta[2, 3] = rng.normal(size=f_dim) * 0.2
    beta[3, 4] = rng.normal(size=f_dim) * 0.2

    emit = EmissionFitParams(
        p_departure=np.array([0.05, 0.05, 0.1, 0.95, 0.05], dtype=np.float64),
        mu_displacement=np.array([0.0, 0.0, 0.0, 1.0, 5.0], dtype=np.float64),
        sigma_displacement=np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float64),
        lambda_comm=np.array([0.5, 1.0, 2.0, 1.5, 0.5], dtype=np.float64),
    )
    params = FitParameters(
        initial=InitialFitParams(logits=np.array([0.0, -1.0, -1.0, -1.0, -1.0])),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
    )
    fb = forward_backward(params, data)
    assert fb.log_gamma.shape == (n, t_plus_1, k)
    assert fb.log_xi.shape == (n, t_plus_1 - 1, k, k)
    # All initial logits and all self-transitions are finite, and every
    # emission has positive probability, so each household's likelihood is
    # strictly positive: its log must be finite (no -inf/NaN leakage).
    assert np.all(np.isfinite(fb.log_likelihood))
    gamma = np.exp(fb.log_gamma)
    assert np.allclose(gamma.sum(axis=2), 1.0, atol=1e-9)
    # Pairwise posteriors must stay consistent with gamma even where some
    # transitions carry -inf log-probability.
    xi = np.exp(fb.log_xi)
    assert np.allclose(xi.sum(axis=3), gamma[:, :-1, :], atol=1e-9)
    assert np.allclose(xi.sum(axis=2), gamma[:, 1:, :], atol=1e-9)