# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac diagnose recovery``: align a fit to truth and write a TOML report."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import numpy as np
import tomli_w
from iohmm_evac.diagnostics.alignment import align_states, apply_permutation
from iohmm_evac.diagnostics.recovery import (
align_fit_to_truth,
parameter_recovery,
state_recovery_accuracy,
state_recovery_confusion,
)
from iohmm_evac.inference.cli import _build_dgp_truth_from_config
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.fit_params import K, dgp_truth_to_fit_init
from iohmm_evac.inference.io import read_fit_bundle
from iohmm_evac.report.loader import load_bundle
# Public names of this module: the hook that registers the ``diagnose``
# subcommand, and the handler that executes it.
__all__ = ["add_diagnose_subparser", "run_diagnose"]
def add_diagnose_subparser(
    subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Attach the ``diagnose`` command group to the top-level CLI parser.

    The ``diagnose`` parser requires a nested action, stored on the parsed
    namespace as ``action``. The only action registered here is ``recovery``,
    which takes:

    * ``--fit`` (required): the fit bundle directory, parsed as a ``Path``.
    * ``--truth`` (required): the original simulation bundle, parsed as a ``Path``.
    * ``--output`` (optional, default ``None``): where to write the report.
    """
    diagnose_parser = subparsers.add_parser(
        "diagnose",
        help="Recovery diagnostics for a saved fit.",
    )
    action_parsers = diagnose_parser.add_subparsers(dest="action", required=True)

    recovery_parser = action_parsers.add_parser(
        "recovery",
        help="Compute state-and-parameter recovery vs truth.",
    )
    recovery_parser.add_argument(
        "--fit",
        type=Path,
        required=True,
        help="Fit bundle directory.",
    )
    recovery_parser.add_argument(
        "--truth",
        type=Path,
        required=True,
        help="Original simulation bundle Parquet path.",
    )
    # Left as None here; the help text documents the fallback location.
    recovery_parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Output recovery.toml path. Default: <fit-dir>/recovery.toml.",
    )
def _to_serializable(arr: np.ndarray) -> list[list[float]] | list[float]:
if arr.ndim == 1:
return [float(x) for x in arr]
return [[float(x) for x in row] for row in arr]
def run_diagnose(args: argparse.Namespace) -> int:
    """Run the ``diagnose recovery`` action and write its TOML report.

    Loads the truth simulation bundle from ``args.truth`` and the fit bundle
    from ``args.fit``. Aligns the fitted state labels to the true state path,
    then scores state recovery (accuracy and confusion) and parameter
    recovery (RMSEs) against parameters rebuilt from the truth bundle's
    config. The report is written to ``args.output``, or to
    ``args.fit / "recovery.toml"`` when ``args.output`` is ``None``. The
    report path and the accuracy are echoed to stderr.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: if ``args.action`` is not ``"recovery"``, or if the truth
            bundle yields no true state path.
    """
    action = args.action
    if action != "recovery":
        msg = f"Unknown diagnose action: {action!r}"
        raise ValueError(msg)

    truth_bundle = load_bundle(args.truth)
    fit_bundle = read_fit_bundle(args.fit)
    fit_data = bundle_to_fit_data(truth_bundle)
    true_states = fit_data.true_states
    if true_states is None:
        msg = "Truth bundle has no state path; cannot diagnose recovery."
        raise ValueError(msg)

    # Relabel the fitted states with the permutation that best matches the
    # truth before any scoring; the same permutation is reused for params.
    posterior = fit_bundle.posterior_states
    perm = align_states(true_states, posterior, K)
    relabelled = apply_permutation(posterior, perm)
    confusion = state_recovery_confusion(true_states, relabelled)
    accuracy = state_recovery_accuracy(true_states, relabelled)

    # Rebuild the generating parameters from the truth config and convert
    # them with dgp_truth_to_fit_init so they can be compared to the fit.
    dgp_truth = _build_dgp_truth_from_config(truth_bundle.config)
    truth_params = dgp_truth_to_fit_init(
        dgp_truth.transitions, dgp_truth.emissions, dgp_truth.population
    )
    report = parameter_recovery(
        truth_params, align_fit_to_truth(fit_bundle.params, perm)
    )

    if args.output is None:
        output_path = args.fit / "recovery.toml"
    else:
        output_path = args.output

    # Coerce everything to builtin int/float so tomli_w can serialize it.
    state_section = {
        "accuracy": float(accuracy),
        "confusion": _to_serializable(confusion),
        "permutation": list(map(int, perm)),
    }
    transition_section = {
        "alpha": float(report.transition_alpha_rmse),
        "beta": float(report.transition_beta_rmse),
    }
    emission_section = {
        "p_departure": float(report.emission_p_rmse),
        "mu_displacement": float(report.emission_mu_rmse),
        "sigma_displacement": float(report.emission_sigma_rmse),
        "lambda_comm": float(report.emission_lambda_rmse),
    }
    payload = {
        "state_recovery": state_section,
        "transition_rmse": transition_section,
        "emission_rmse": emission_section,
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Encode explicitly and write bytes: this pins UTF-8 and bypasses
    # text-mode newline translation.
    output_path.write_bytes(tomli_w.dumps(payload).encode("utf-8"))

    print(f"recovery: {output_path}", file=sys.stderr)
    print(f"state recovery accuracy: {accuracy:.4f}", file=sys.stderr)
    return 0