# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""EM-loop tests: monotonicity and convergence on a synthetic problem."""
from __future__ import annotations
import numpy as np
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.em import EMConfig, run_em
from iohmm_evac.inference.fit_params import (
EmissionFitParams,
FitParameters,
InitialFitParams,
TransitionFitParams,
)
from iohmm_evac.inference.forward_backward import forward_backward
from tests._clean_dgp import generate
def _make_clean_truth(rng: np.random.Generator) -> tuple[FitParameters, FitData]:
    """A small K=3, T=20, N=50 problem generated from a clean DGP."""
    n_states, n_features, n_agents, n_steps = 3, 2, 50, 21
    # Baseline transition logits; each diagonal entry is the zero reference class.
    alpha = np.array(
        [
            [0.0, -1.5, -3.0],
            [-3.0, 0.0, -1.5],
            [-3.0, -3.0, 0.0],
        ],
        dtype=np.float64,
    )
    # Input effects: feature 0 drives the 0->1 transition, feature 1 drives 1->2.
    beta = np.zeros((n_states, n_states, n_features))
    beta[0, 1] = np.array([1.5, 0.0])
    beta[1, 2] = np.array([0.0, 1.5])
    truth = FitParameters(
        initial=InitialFitParams(logits=np.array([0.0, -2.0, -2.0])),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=EmissionFitParams(
            p_departure=np.array([0.05, 0.5, 0.95]),
            mu_displacement=np.array([0.0, 1.0, 5.0]),
            sigma_displacement=np.array([0.5, 0.5, 0.5]),
            lambda_comm=np.array([0.5, 1.5, 3.0]),
        ),
    )
    # u[:, t, 0] ramps with time; u[:, t, 1] turns on at t >= T/2.
    inputs = np.zeros((n_agents, n_steps, n_features), dtype=np.float64)
    inputs[:, :, 0] = np.linspace(0.0, 1.0, n_steps)[None, :]
    inputs[:, n_steps // 2 :, 1] = 1.0
    # Draw one synthetic sample from the clean data-generating process.
    sample = generate(truth, inputs, rng)
    data = FitData(
        inputs=inputs,
        departure=sample.departure,
        displacement=sample.displacement,
        comm=sample.comm,
        true_states=sample.states,
    )
    return truth, data
def _perturb(truth: FitParameters, rng: np.random.Generator) -> FitParameters:
    """Move ``truth`` slightly so EM has to actually do work to recover it.

    Adds small Gaussian noise to every free parameter while respecting each
    parameter's constraints: departure probabilities stay in (0, 1), scales
    (sigma, lambda) stay positive, and pinned transition logits (diagonal
    reference entries and any non-finite values) are left untouched.

    Fix: the emission-noise draws previously hard-coded ``standard_normal(3)``,
    silently tying this helper to K=3; they now use each parameter's own shape,
    which draws the identical values for K=3 and generalizes to any K.
    """
    alpha = truth.transitions.alpha.copy()
    beta = truth.transitions.beta.copy()
    # Only perturb off-diagonal, finite logits: the diagonal is the reference
    # class, and -inf entries (if present) encode structurally forbidden moves.
    finite_mask = np.isfinite(alpha) & ~np.eye(alpha.shape[0], dtype=bool)
    alpha = np.where(finite_mask, alpha + 0.3 * rng.standard_normal(alpha.shape), alpha)
    beta = beta + 0.2 * rng.standard_normal(beta.shape)
    init_logits = truth.initial.logits + 0.2 * rng.standard_normal(truth.initial.logits.shape)
    em = truth.emissions
    # Noise shapes follow the parameters themselves (K-generic, not literal 3).
    p = np.clip(em.p_departure + 0.05 * rng.standard_normal(em.p_departure.shape), 1e-3, 1 - 1e-3)
    mu = em.mu_displacement + 0.3 * rng.standard_normal(em.mu_displacement.shape)
    sigma = np.maximum(em.sigma_displacement + 0.2 * rng.standard_normal(em.sigma_displacement.shape), 0.1)
    lam = np.maximum(em.lambda_comm + 0.2 * rng.standard_normal(em.lambda_comm.shape), 0.1)
    emit = EmissionFitParams(
        p_departure=p,
        mu_displacement=mu,
        sigma_displacement=sigma,
        lambda_comm=lam,
    )
    return FitParameters(
        initial=InitialFitParams(logits=init_logits),
        transitions=TransitionFitParams(alpha=alpha, beta=beta),
        emissions=emit,
        feature_names=truth.feature_names,
    )
def test_em_log_likelihood_is_non_decreasing() -> None:
    """Every EM iteration must leave the data log-likelihood no worse."""
    rng = np.random.default_rng(7)
    truth, data = _make_clean_truth(rng)
    start = _perturb(truth, rng)
    result = run_em(start, data, EMConfig(max_iter=20, tol=1e-8))
    trace = result.log_likelihood_trace
    assert len(trace) >= 2
    # Allow a tiny negative slack for floating-point noise in the E-step sums.
    deltas = np.diff(np.asarray(trace))
    assert np.all(deltas > -1e-9), f"non-monotone trace: {trace}"
def test_em_converges_close_to_truth_log_likelihood() -> None:
    """With enough iterations, the fitted LL should land near the truth's LL."""
    rng = np.random.default_rng(11)
    truth, data = _make_clean_truth(rng)
    start = _perturb(truth, rng)
    result = run_em(start, data, EMConfig(max_iter=100, tol=1e-8))
    # Evaluate the generating parameters on the same data as the baseline.
    truth_ll = float(forward_backward(truth, data).log_likelihood.sum())
    fit_ll = result.final_log_likelihood
    n_obs = data.n * (data.t_total + 1)
    per_obs_gap = (truth_ll - fit_ll) / n_obs
    # Relax slightly from the design target of ``1e-3`` since the optimization
    # is finite-N and may exceed the truth's likelihood by overfitting.
    assert per_obs_gap < 5e-2, (
        f"per-observation LL gap {per_obs_gap:.4e} too large "
        f"(truth={truth_ll:.3f}, fit={fit_ll:.3f}, n_obs={n_obs})"
    )