# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Recovery-flavored ``iohmm-evac report`` subcommands and dispatch.
Wired by :mod:`iohmm_evac.report.cli`; lives in its own module to keep
``report/cli.py`` under the project's 300-line per-file ceiling.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from iohmm_evac.diagnostics.alignment import align_states, apply_permutation
from iohmm_evac.diagnostics.recovery import (
align_fit_to_truth,
parameter_recovery,
state_recovery_accuracy,
state_recovery_confusion,
)
from iohmm_evac.inference.cli import _build_dgp_truth_from_config
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.fit_params import K, dgp_truth_to_fit_init
from iohmm_evac.inference.io import read_fit_bundle
from iohmm_evac.report.constants import DEFAULT_DPI
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.report.recovery_plots import (
plot_log_likelihood_trace,
plot_parameter_recovery,
plot_state_recovery_confusion,
)
# Public surface of this module: the subcommand-name tuple plus the two entry
# points (parser registration and dispatch).
__all__ = [
    "FIT_SUBCOMMANDS",
    "add_fit_report_subparsers",
    "run_fit_report",
]
# Names of the fit-related ``report`` subcommands. These are the same strings
# that ``add_fit_report_subparsers`` registers and that ``run_fit_report``
# compares ``args.action`` against, so the three places must be kept in sync.
FIT_SUBCOMMANDS: tuple[str, ...] = (
    "fit-summary",
    "recovery-confusion",
    "parameter-recovery",
    "ll-trace",
)
def add_fit_report_subparsers(
    actions: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Attach the four fit-related ``report`` subcommands to ``actions``.

    The subparsers are created in this order: ``fit-summary``,
    ``recovery-confusion``, ``parameter-recovery``, ``ll-trace``. Every one
    requires ``--fit``; the two recovery commands additionally require
    ``--truth``; every command except ``fit-summary`` accepts an optional
    ``--output`` whose default is ``None``.

    Args:
        actions: Subparser collection (the object returned by
            ``ArgumentParser.add_subparsers``) to register the commands on.
    """
    # One row per subcommand: (name, help text, takes --truth, takes --output).
    specs = (
        (
            "fit-summary",
            "Print a fit-bundle summary (LLs, iterations, convergence).",
            False,
            False,
        ),
        (
            "recovery-confusion",
            "Confusion-matrix heatmap (true x fit-aligned).",
            True,
            True,
        ),
        (
            "parameter-recovery",
            "Scatter of true vs estimated parameter values.",
            True,
            True,
        ),
        (
            "ll-trace",
            "Log-likelihood trace per restart over EM iterations.",
            False,
            True,
        ),
    )
    for name, help_text, wants_truth, wants_output in specs:
        sub = actions.add_parser(name, help=help_text)
        sub.add_argument("--fit", type=Path, required=True, help="Fit bundle directory.")
        if wants_truth:
            sub.add_argument(
                "--truth",
                type=Path,
                required=True,
                help="Truth simulation bundle Parquet path.",
            )
        if wants_output:
            sub.add_argument("--output", type=Path, default=None, help="PNG output path.")
def _save_or_show(fig: Figure, output: Path | None, default_path: Path) -> Path:
    """Save ``fig`` to disk, close it, and return the path that was written.

    Despite the name, the figure is never displayed interactively: it is
    always saved, to ``output`` when one is given and to ``default_path``
    otherwise. Missing parent directories of the destination are created.

    Args:
        fig: Figure to save and close.
        output: Explicit destination, or ``None`` to fall back to
            ``default_path``.
        default_path: Destination used when ``output`` is ``None``.

    Returns:
        The path the figure was saved to.

    Raises:
        Exception: Whatever directory creation or ``fig.savefig`` raises is
            propagated unchanged, after the figure has been closed.
    """
    target = default_path if output is None else output
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(target, dpi=DEFAULT_DPI, bbox_inches="tight")
    finally:
        # Close on the failure path too, so a failed save does not leave the
        # figure registered with pyplot.
        plt.close(fig)
    return target
def _print_fit_summary(args: argparse.Namespace) -> int:
    """Print a plain-text summary of the fit bundle at ``args.fit`` to stdout.

    The fit directory, the module-level state count ``K``, the feature
    names, the number of restarts and the best restart index are always
    printed. The per-restart final log-likelihoods, iteration counts and
    convergence flags are each printed only when the corresponding bundle
    field is truthy.

    Args:
        args: Parsed CLI namespace; only ``args.fit`` is read.

    Returns:
        Always ``0``.
    """
    summary = read_fit_bundle(args.fit)

    def emit(line: str) -> None:
        print(line, file=sys.stdout)

    emit(f"fit dir: {args.fit}")
    emit(f"K (states): {K}")
    emit(f"feature_names: {list(summary.params.feature_names)}")
    emit(f"restarts: {len(summary.log_likelihood_traces)}")
    emit(f"best restart: #{summary.best_index}")

    final_lls = summary.final_log_likelihoods
    if final_lls:
        emit("final LLs: " + ", ".join(format(value, ".3f") for value in final_lls))

    iteration_counts = summary.iterations_per_restart
    if iteration_counts:
        emit("iterations: " + ", ".join(map(str, iteration_counts)))

    converged_flags = summary.converged_per_restart
    if converged_flags:
        emit("converged: " + ", ".join(map(str, converged_flags)))

    return 0
def _draw_confusion(args: argparse.Namespace) -> int:
    """Plot the state-recovery confusion matrix and save it as an image.

    The truth bundle at ``args.truth`` and the fit bundle at ``args.fit`` are
    loaded, the fitted posterior states are relabeled with the permutation
    returned by ``align_states``, and the resulting confusion matrix is drawn
    with the recovery accuracy (three decimals) in the axes title. The figure
    goes to ``args.output`` or, when that is ``None``, to
    ``<args.fit>/recovery_confusion.png``.

    Args:
        args: Parsed CLI namespace; reads ``truth``, ``fit`` and ``output``.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If the truth bundle carries no true state path.
    """
    truth_bundle = load_bundle(args.truth)
    fit = read_fit_bundle(args.fit)
    true_states = bundle_to_fit_data(truth_bundle).true_states
    if true_states is None:
        raise ValueError("Truth bundle has no state path; cannot draw confusion.")

    # Relabel the fitted states before comparing them with the truth.
    permutation = align_states(true_states, fit.posterior_states, K)
    relabeled = apply_permutation(fit.posterior_states, permutation)
    matrix = state_recovery_confusion(true_states, relabeled)
    score = state_recovery_accuracy(true_states, relabeled)

    fig, ax = plt.subplots(figsize=(5, 5))
    plot_state_recovery_confusion(matrix, ax=ax)
    ax.set_title(f"State recovery (accuracy={score:.3f})")

    _save_or_show(fig, args.output, args.fit / "recovery_confusion.png")
    return 0
def _draw_parameter_recovery(args: argparse.Namespace) -> int:
    """Plot true versus estimated parameter values and save the figure.

    The truth bundle at ``args.truth`` supplies both the true state path
    (used to compute the state permutation via ``align_states``) and the
    config from which the true parameters are rebuilt. The fitted parameters
    are permuted with ``align_fit_to_truth`` before being compared. The
    figure goes to ``args.output`` or, when that is ``None``, to
    ``<args.fit>/parameter_recovery.png``.

    Args:
        args: Parsed CLI namespace; reads ``truth``, ``fit`` and ``output``.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If the truth bundle carries no true state path.
    """
    truth_bundle = load_bundle(args.truth)
    fit = read_fit_bundle(args.fit)
    true_states = bundle_to_fit_data(truth_bundle).true_states
    if true_states is None:
        raise ValueError("Truth bundle has no state path; cannot draw recovery.")

    permutation = align_states(true_states, fit.posterior_states, K)

    # Rebuild the ground-truth parameters from the simulation config.
    dgp = _build_dgp_truth_from_config(truth_bundle.config)
    true_params = dgp_truth_to_fit_init(dgp.transitions, dgp.emissions, dgp.population)

    fitted_params = align_fit_to_truth(fit.params, permutation)
    recovery = parameter_recovery(true_params, fitted_params)

    fig, ax = plt.subplots(figsize=(7, 7))
    plot_parameter_recovery(recovery, ax=ax)

    _save_or_show(fig, args.output, args.fit / "parameter_recovery.png")
    return 0
def _draw_ll_trace(args: argparse.Namespace) -> int:
    """Plot the per-restart EM log-likelihood traces and save the figure.

    A 1x2 grid of axes is created and handed, as a list, to
    ``plot_log_likelihood_trace`` together with the traces (each copied into
    a plain list) and the bundle's best restart index. The figure goes to
    ``args.output`` or, when that is ``None``, to ``<args.fit>/ll_trace.png``.

    Args:
        args: Parsed CLI namespace; reads ``fit`` and ``output``.

    Returns:
        Always ``0``.
    """
    fit = read_fit_bundle(args.fit)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    fig.suptitle("EM log-likelihood traces (units: total LL across all (i,t))", fontsize=10)

    plot_log_likelihood_trace(
        [list(trace) for trace in fit.log_likelihood_traces],
        ax=list(axes),
        best_index=fit.best_index,
    )

    _save_or_show(fig, args.output, args.fit / "ll_trace.png")
    return 0
def run_fit_report(args: argparse.Namespace) -> int:
    """Run the fit-related ``report`` subcommand named by ``args.action``.

    Args:
        args: Parsed CLI namespace. ``args.action`` selects the handler and
            the whole namespace is passed on to it.

    Returns:
        The selected handler's return value.

    Raises:
        ValueError: If ``args.action`` is not one of ``fit-summary``,
            ``recovery-confusion``, ``parameter-recovery`` or ``ll-trace``.
    """
    action = args.action
    # Pick the handler first, then call it from a single place.
    if action == "fit-summary":
        handler = _print_fit_summary
    elif action == "recovery-confusion":
        handler = _draw_confusion
    elif action == "parameter-recovery":
        handler = _draw_parameter_recovery
    elif action == "ll-trace":
        handler = _draw_ll_trace
    else:
        raise ValueError(f"Unknown fit-report action: {action!r}")
    return handler(args)