tests/test_m_step.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Numerical-gradient and closed-form sanity checks for the M-step."""

from __future__ import annotations

import numpy as np
import pytest
from scipy.optimize import check_grad

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
    EmissionFitParams,
    FitParameters,
    InitialFitParams,
    TransitionFitParams,
)
from iohmm_evac.inference.forward_backward import forward_backward
from iohmm_evac.inference.m_step import (
    transition_neg_q_and_grad,
    update_emissions,
    update_initial,
)


def _random_problem(seed: int = 0) -> tuple[FitParameters, FitData]:
    """A small K=3, T=10, N=20 problem with random ξ/γ for the gradient check."""
    rng = np.random.default_rng(seed)
    k = 3
    f = 4
    n = 20
    t_plus_1 = 11
    inputs = rng.normal(size=(n, t_plus_1, f)) * 0.5

    alpha = np.array(
        [
            [0.0, -1.0, -2.0],
            [-1.5, 0.0, -1.0],
            [-2.0, -2.0, 0.0],
        ],
        dtype=np.float64,
    )
    beta = rng.normal(size=(k, k, f)) * 0.2
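    # Pin the self-transition parameters to zero: the diagonal presumably acts as the
    # softmax reference category, matching the j != origin_k learnable mask used in
    # the gradient test below.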
    np.fill_diagonal(alpha, 0.0)
    for kk in range(k):
        beta[kk, kk] = 0.0
    emit = EmissionFitParams(
        p_departure=np.array([0.1, 0.5, 0.9]),
        mu_displacement=np.array([0.0, 1.0, 4.0]),
        sigma_displacement=np.array([1.0, 1.0, 1.0]),
        lambda_comm=np.array([0.5, 1.5, 3.0]),
    )
    params = FitParameters(
        initial=InitialFitParams(logits=np.array([0.0, -0.5, -1.0])),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
    )

    departure = (rng.random((n, t_plus_1)) < 0.4).astype(np.float64)
    displacement = rng.normal(size=(n, t_plus_1)) * 0.5 + 1.0
    comm = rng.poisson(1.5, size=(n, t_plus_1)).astype(np.float64)
    data = FitData(
        inputs=inputs,
        departure=departure,
        displacement=displacement,
        comm=comm,
        true_states=None,
    )
    return params, data


def test_transition_gradient_numerical_check() -> None:
    params, data = _random_problem(seed=42)
    fb = forward_backward(params, data)
    xi = np.exp(fb.log_xi)
    gamma = np.exp(fb.log_gamma)

    n, t_total, k, _ = xi.shape
    f = data.inputs.shape[2]
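    # xi[:, t] couples the state at time t with the state at t + 1, so gamma keeps
    # only the origin side (its final time index is dropped) and the covariates for
    # each step come from inputs[:, 1:, :]. Flattening over (n, t) turns the
    # per-origin Q-term into a plain sum over rows.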
    inputs_steps = data.inputs[:, 1:, :].reshape(n * t_total, f)
    xi_flat = xi.reshape(n * t_total, k, k)
    gamma_flat = gamma[:, :-1, :].reshape(n * t_total, k)

    rng = np.random.default_rng(0)
    for origin_k in range(k):
        learnable_j = np.array([j != origin_k for j in range(k)], dtype=bool)
        n_dest = int(learnable_j.sum())
        x0 = rng.normal(size=n_dest + n_dest * f) * 0.5

        # Default arguments pin origin_k / learnable_j at definition time, so the
        # closures below do not all capture the final value of the loop variable.
        def fun(x: np.ndarray, ok: int = origin_k, lj: np.ndarray = learnable_j) -> float:
            val, _grad = transition_neg_q_and_grad(
                x,
                origin_k=ok,
                learnable_j=lj,
                xi_origin=xi_flat[:, ok, :],
                gamma_origin=gamma_flat[:, ok],
                inputs_steps=inputs_steps,
            )
            return val

        def grad(x: np.ndarray, ok: int = origin_k, lj: np.ndarray = learnable_j) -> np.ndarray:
            _val, g = transition_neg_q_and_grad(
                x,
                origin_k=ok,
                learnable_j=lj,
                xi_origin=xi_flat[:, ok, :],
                gamma_origin=gamma_flat[:, ok],
                inputs_steps=inputs_steps,
            )
            return g

        err = check_grad(fun, grad, x0)
        # check_grad compares against a finite-difference gradient; 1e-4 leaves slack
        # on this tiny float64 problem.
        assert err < 1e-4, f"origin {origin_k}: numerical-gradient error {err:.3e}"
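

def test_transition_objective_improves_under_lbfgs() -> None:
    """Optimizer-level sketch, not a spec of the real M-step wiring.

    Assuming ``transition_neg_q_and_grad`` is the objective handed to a quasi-Newton
    solver inside the transition M-step (its signature is taken from the gradient
    test above), L-BFGS-B started from zero should never make the negative Q worse.
    """
    from scipy.optimize import minimize  # local import: only this sketch needs it

    params, data = _random_problem(seed=3)
    fb = forward_backward(params, data)
    xi = np.exp(fb.log_xi)
    gamma = np.exp(fb.log_gamma)

    n, t_total, k, _ = xi.shape
    f = data.inputs.shape[2]
    inputs_steps = data.inputs[:, 1:, :].reshape(n * t_total, f)
    xi_flat = xi.reshape(n * t_total, k, k)
    gamma_flat = gamma[:, :-1, :].reshape(n * t_total, k)

    origin_k = 0
    learnable_j = np.arange(k) != origin_k
    n_dest = int(learnable_j.sum())

    def obj(x: np.ndarray) -> tuple[float, np.ndarray]:
        # Same call as in the gradient test: returns (negative Q, gradient).
        return transition_neg_q_and_grad(
            x,
            origin_k=origin_k,
            learnable_j=learnable_j,
            xi_origin=xi_flat[:, origin_k, :],
            gamma_origin=gamma_flat[:, origin_k],
            inputs_steps=inputs_steps,
        )

    x0 = np.zeros(n_dest + n_dest * f)
    res = minimize(obj, x0, jac=True, method="L-BFGS-B")
    assert np.all(np.isfinite(res.x))
    assert res.fun <= obj(x0)[0] + 1e-9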


def test_emission_closed_form_matches_brute_force() -> None:
    """With known posteriors, weighted-MLE emission updates match a hand-rolled MLE."""
    params, data = _random_problem(seed=1)
    fb = forward_backward(params, data)
    new_emit = update_emissions(fb.log_gamma, data, sigma_floor=1e-6)

    gamma = np.exp(fb.log_gamma)
    weights = gamma.sum(axis=(0, 1))  # (K,)
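    # Hand-rolled weighted MLEs: for each state k, the Bernoulli rate, Gaussian mean,
    # and Poisson rate all reduce to posterior-weighted averages of the observations,
    #   p_k = sum_{n,t} gamma_{n,t}(k) * y_{n,t} / sum_{n,t} gamma_{n,t}(k), etc.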
    expected_p = (gamma * data.departure[:, :, None]).sum(axis=(0, 1)) / weights
    expected_mu = (gamma * data.displacement[:, :, None]).sum(axis=(0, 1)) / weights
    expected_lam = (gamma * data.comm[:, :, None]).sum(axis=(0, 1)) / weights

    assert new_emit.p_departure == pytest.approx(np.clip(expected_p, 1e-6, 1 - 1e-6), abs=1e-9)
    assert new_emit.mu_displacement == pytest.approx(expected_mu, abs=1e-9)
    assert new_emit.lambda_comm == pytest.approx(np.maximum(expected_lam, 1e-6), abs=1e-9)


def test_initial_closed_form_normalizes() -> None:
    params, data = _random_problem(seed=2)
    fb = forward_backward(params, data)
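    # update_initial presumably renormalizes the t = 0 posteriors
    # (pi_k proportional to sum_n gamma_{n,0}(k)); here we only check that the result
    # is a proper distribution, not the exact closed form.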
    new_init = update_initial(fb.log_gamma)
    probs = new_init.probs()
    assert probs.sum() == pytest.approx(1.0, abs=1e-9)
    assert (probs >= 0).all()