# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Direct (in-process) coverage of inference and report-recovery modules."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
from iohmm_evac.cli import main
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.diagnostics.alignment import align_states
from iohmm_evac.diagnostics.recovery import (
align_fit_to_truth,
parameter_recovery,
state_recovery_accuracy,
state_recovery_confusion,
)
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.fit import fit
from iohmm_evac.inference.fit_params import dgp_truth_to_fit_init
from iohmm_evac.inference.initialization import (
from_dgp_truth,
kmeans_init,
random_initialization,
)
from iohmm_evac.inference.io import read_fit_bundle, write_fit_bundle
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.report.recovery_plots import (
plot_log_likelihood_trace,
plot_parameter_recovery,
plot_state_recovery_confusion,
)
@pytest.fixture
def tiny_bundle_path(tmp_path: Path) -> Path:
    """Simulate a small synthetic panel and persist it as a parquet bundle.

    Uses 80 households over 18 hours with seed 0, a deliberately small panel
    so the fits in this module stay cheap. Returns the path of the written file.
    """
    cfg = SimulationConfig(n_households=80, n_hours=18, seed=0)
    simulated = simulate(cfg, np.random.default_rng(cfg.seed))
    destination = tmp_path / "tiny.parquet"
    write_results(simulated, destination)
    return destination
def test_random_initialization_shapes() -> None:
    """Random initialization yields parameter arrays sized for five states."""
    init = random_initialization(np.random.default_rng(0))
    transitions, emissions = init.transitions, init.emissions
    assert transitions.alpha.shape == (5, 5)
    assert transitions.beta.shape[0] == 5
    assert emissions.p_departure.shape == (5,)
def test_kmeans_init_returns_valid_params(tiny_bundle_path: Path) -> None:
    """k-means initialization on the tiny bundle gives finite displacement means."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    init = kmeans_init(fit_data, np.random.default_rng(0))
    mu = init.emissions.mu_displacement
    assert np.isfinite(mu).all()
def test_from_dgp_truth_matches_dgp_truth_adapter() -> None:
    """``from_dgp_truth`` and ``dgp_truth_to_fit_init`` build the same parameters.

    Compares the transition and the emission arrays, not just ``alpha``, so a
    divergence in either half of the parameter set is caught.
    """
    cfg = SimulationConfig()
    a = from_dgp_truth(cfg.transitions, cfg.emissions)
    b = dgp_truth_to_fit_init(cfg.transitions, cfg.emissions)
    np.testing.assert_array_equal(a.transitions.alpha, b.transitions.alpha)
    np.testing.assert_array_equal(a.transitions.beta, b.transitions.beta)
    np.testing.assert_array_equal(a.emissions.p_departure, b.emissions.p_departure)
    np.testing.assert_array_equal(
        a.emissions.mu_displacement, b.emissions.mu_displacement
    )
def test_fit_io_round_trip(tmp_path: Path, tiny_bundle_path: Path) -> None:
    """A fit written with ``write_fit_bundle`` reads back via ``read_fit_bundle``.

    Runs a two-iteration, truth-initialized fit and writes it together with an
    all-zeros posterior state array. It then checks that every reported
    artifact exists and that the parameters and posterior survive the round
    trip.
    """
    bundle = load_bundle(tiny_bundle_path)
    data = bundle_to_fit_data(bundle)
    cfg_truth = SimulationConfig()
    truth_init = dgp_truth_to_fit_init(cfg_truth.transitions, cfg_truth.emissions)
    result = fit(
        data,
        n_restarts=1,
        em_config=EMConfig(max_iter=2, tol=1e-3),
        init="truth",
        rng=np.random.default_rng(0),
        truth_init=truth_init,
    )
    # Placeholder decoded path shaped (data.n, data.t_total + 1); its contents
    # only need to round-trip unchanged.
    posterior = np.zeros((data.n, data.t_total + 1), dtype=np.int64)
    fit_dir = tmp_path / "fit"
    paths = write_fit_bundle(result, posterior, fit_dir)
    # Without this, the existence loop below would pass vacuously on an empty
    # mapping.
    assert paths, "write_fit_bundle reported no output files"
    for label, p in paths.items():
        assert p.exists(), label
    fit_bundle = read_fit_bundle(fit_dir)
    assert fit_bundle.params.transitions.alpha.shape == (5, 5)
    np.testing.assert_array_equal(fit_bundle.posterior_states, posterior)
def test_fit_random_init_runs(tiny_bundle_path: Path) -> None:
    """Two random restarts each record a run, and one is selected as best."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    em = EMConfig(max_iter=2, tol=1e-3)
    outcome = fit(
        fit_data,
        n_restarts=2,
        em_config=em,
        init="random",
        rng=np.random.default_rng(0),
    )
    assert len(outcome.all_runs) == 2
    assert outcome.best_index in (0, 1)
def test_fit_kmeans_init_runs(tiny_bundle_path: Path) -> None:
    """A single k-means-initialized restart completes at least one EM iteration."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    outcome = fit(
        fit_data,
        n_restarts=1,
        em_config=EMConfig(max_iter=2, tol=1e-3),
        init="kmeans",
        rng=np.random.default_rng(0),
    )
    best_run = outcome.best
    assert best_run.iterations >= 1
def test_fit_truth_init_requires_truth(tiny_bundle_path: Path) -> None:
    """``init="truth"`` without ``truth_init`` is rejected with a ValueError."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    kwargs = {
        "n_restarts": 1,
        "em_config": EMConfig(max_iter=1),
        "init": "truth",
        "rng": np.random.default_rng(0),
    }
    with pytest.raises(ValueError, match="truth"):
        fit(fit_data, **kwargs)
def test_fit_unknown_init_strategy(tiny_bundle_path: Path) -> None:
    """An unrecognized ``init`` name raises a ValueError naming the problem."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    kwargs = {
        "n_restarts": 1,
        "em_config": EMConfig(max_iter=1),
        "init": "bogus",
        "rng": np.random.default_rng(0),
    }
    with pytest.raises(ValueError, match="Unknown init strategy"):
        fit(fit_data, **kwargs)
def test_fit_zero_restarts_rejected(tiny_bundle_path: Path) -> None:
    """Requesting zero restarts is rejected with a ValueError about n_restarts."""
    fit_data = bundle_to_fit_data(load_bundle(tiny_bundle_path))
    with pytest.raises(ValueError, match="n_restarts"):
        fit(fit_data, n_restarts=0)
def test_recovery_plots_render(tmp_path: Path, tiny_bundle_path: Path) -> None:
    """Fit briefly, decode, align to truth, and render every recovery plot.

    Exercises the diagnostics chain: Viterbi decoding, state alignment,
    confusion/accuracy, and parameter recovery. It then draws the confusion,
    parameter-recovery, and log-likelihood-trace plots onto fresh axes.
    """
    import matplotlib.pyplot as plt

    from iohmm_evac.diagnostics.decoding import viterbi

    bundle = load_bundle(tiny_bundle_path)
    data = bundle_to_fit_data(bundle)
    cfg_truth = SimulationConfig()
    truth_init = dgp_truth_to_fit_init(cfg_truth.transitions, cfg_truth.emissions)
    result = fit(
        data,
        n_restarts=1,
        em_config=EMConfig(max_iter=2, tol=1e-3),
        init="truth",
        rng=np.random.default_rng(0),
        truth_init=truth_init,
    )
    fit_path = viterbi(result.best.params, data)
    assert data.true_states is not None
    perm = align_states(data.true_states, fit_path, k=5)
    # Relabel the decoded path into the truth's state numbering once and reuse it.
    aligned_path = perm[fit_path]
    confusion = state_recovery_confusion(data.true_states, aligned_path)
    accuracy = state_recovery_accuracy(data.true_states, aligned_path)
    aligned = align_fit_to_truth(result.best.params, perm)
    report = parameter_recovery(truth_init, aligned)
    traces = [list(r.log_likelihood_trace) for r in result.all_runs]

    try:
        _, ax = plt.subplots()
        plot_state_recovery_confusion(confusion, ax=ax)
        _, ax = plt.subplots()
        plot_parameter_recovery(report, ax=ax)
        _, axes = plt.subplots(1, 2)
        plot_log_likelihood_trace(
            traces,
            ax=list(axes),
            best_index=result.best_index,
        )
    finally:
        # Close figures even if a plot call raises, so a failure here does not
        # leak open figures into later tests.
        plt.close("all")
    assert 0.0 <= accuracy <= 1.0
def test_cli_fit_diagnose_inproc(tmp_path: Path, tiny_bundle_path: Path) -> None:
    """CLI ``fit`` then ``diagnose recovery`` succeed and write their TOML outputs."""
    fit_dir = tmp_path / "fit"
    fit_argv = [
        "fit",
        "--input",
        str(tiny_bundle_path),
        "--output",
        str(fit_dir),
        "--init",
        "truth",
        "--max-iter",
        "2",
        "--quiet",
    ]
    assert main(fit_argv) == 0
    assert (fit_dir / "theta.toml").exists()

    diagnose_argv = [
        "diagnose",
        "recovery",
        "--fit",
        str(fit_dir),
        "--truth",
        str(tiny_bundle_path),
    ]
    assert main(diagnose_argv) == 0
    assert (fit_dir / "recovery.toml").exists()
def test_cli_report_recovery_inproc(tmp_path: Path, tiny_bundle_path: Path) -> None:
    """Every recovery ``report`` subcommand runs in-process against a CLI fit.

    Fits with ``fit --init truth --max-iter 2`` and then runs the
    ``recovery-confusion``, ``parameter-recovery``, ``ll-trace``, and
    ``fit-summary`` reports. It checks each command's exit code and that every
    figure-producing report wrote its ``--output`` file.
    """
    fit_dir = tmp_path / "fit"
    rc = main(
        [
            "fit",
            "--input",
            str(tiny_bundle_path),
            "--output",
            str(fit_dir),
            "--init",
            "truth",
            "--max-iter",
            "2",
            "--quiet",
        ]
    )
    # Fail here, not in a later report command, if the fit itself broke.
    assert rc == 0

    confusion_png = tmp_path / "c.png"
    rc = main(
        [
            "report",
            "recovery-confusion",
            "--fit",
            str(fit_dir),
            "--truth",
            str(tiny_bundle_path),
            "--output",
            str(confusion_png),
        ]
    )
    assert rc == 0
    assert confusion_png.exists()

    param_png = tmp_path / "p.png"
    rc = main(
        [
            "report",
            "parameter-recovery",
            "--fit",
            str(fit_dir),
            "--truth",
            str(tiny_bundle_path),
            "--output",
            str(param_png),
        ]
    )
    assert rc == 0
    assert param_png.exists()

    ll_png = tmp_path / "ll.png"
    rc = main(["report", "ll-trace", "--fit", str(fit_dir), "--output", str(ll_png)])
    assert rc == 0
    # Check the output file here too, as for the other figure reports.
    assert ll_png.exists()

    rc = main(["report", "fit-summary", "--fit", str(fit_dir)])
    assert rc == 0