# src/iohmm_evac/diagnostics/cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac diagnose recovery``: align a fit to truth and write a TOML report."""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

import numpy as np
import tomli_w

from iohmm_evac.diagnostics.alignment import align_states, apply_permutation
from iohmm_evac.diagnostics.recovery import (
    align_fit_to_truth,
    parameter_recovery,
    state_recovery_accuracy,
    state_recovery_confusion,
)
from iohmm_evac.inference.cli import _build_dgp_truth_from_config
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.fit_params import K, dgp_truth_to_fit_init
from iohmm_evac.inference.io import read_fit_bundle
from iohmm_evac.report.loader import load_bundle

__all__ = ["add_diagnose_subparser", "run_diagnose"]


def add_diagnose_subparser(
    subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Attach the ``diagnose`` command (and its ``recovery`` action) to *subparsers*.

    The chosen action is stored on the namespace as ``action`` and is mandatory.
    ``recovery`` takes two required path options, ``--fit`` and ``--truth``, and
    an optional ``--output`` that defaults to ``None``.
    """
    diagnose_parser = subparsers.add_parser(
        "diagnose", help="Recovery diagnostics for a saved fit."
    )
    action_parsers = diagnose_parser.add_subparsers(dest="action", required=True)
    recovery_parser = action_parsers.add_parser(
        "recovery", help="Compute state-and-parameter recovery vs truth."
    )

    # Both inputs are mandatory paths; registration order fixes the --help order.
    required_paths = (
        ("--fit", "Fit bundle directory."),
        ("--truth", "Original simulation bundle Parquet path."),
    )
    for flag, description in required_paths:
        recovery_parser.add_argument(flag, type=Path, required=True, help=description)

    recovery_parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Output recovery.toml path. Default: <fit-dir>/recovery.toml.",
    )


def _to_serializable(arr: np.ndarray) -> list[list[float]] | list[float]:
    if arr.ndim == 1:
        return [float(x) for x in arr]
    return [[float(x) for x in row] for row in arr]


def run_diagnose(args: argparse.Namespace) -> int:
    """Run the ``diagnose recovery`` action and write its TOML report.

    Loads the truth bundle from ``args.truth`` and the fit bundle from
    ``args.fit``, aligns the fitted states to the true states, and writes the
    state-recovery and parameter-RMSE figures to ``args.output`` (or to
    ``<args.fit>/recovery.toml`` when no output path was given). The report path
    and the state-recovery accuracy are echoed to stderr.

    Returns:
        ``0`` once the report has been written.

    Raises:
        ValueError: If ``args.action`` is not ``"recovery"``, or if the truth
            bundle carries no state path.
    """
    if args.action != "recovery":
        msg = f"Unknown diagnose action: {args.action!r}"
        raise ValueError(msg)

    truth_bundle = load_bundle(args.truth)
    fit = read_fit_bundle(args.fit)
    fit_data = bundle_to_fit_data(truth_bundle)
    true_states = fit_data.true_states
    if true_states is None:
        msg = "Truth bundle has no state path; cannot diagnose recovery."
        raise ValueError(msg)

    # State recovery: relabel the fitted states with the permutation returned by
    # align_states before scoring them against the true states.
    perm = align_states(true_states, fit.posterior_states, K)
    relabelled = apply_permutation(fit.posterior_states, perm)
    confusion = state_recovery_confusion(true_states, relabelled)
    accuracy = state_recovery_accuracy(true_states, relabelled)

    # Parameter recovery: rebuild the truth parameters from the bundle's config
    # and compare them with the fit after applying the same permutation.
    dgp_truth = _build_dgp_truth_from_config(truth_bundle.config)
    truth_params = dgp_truth_to_fit_init(
        dgp_truth.transitions, dgp_truth.emissions, dgp_truth.population
    )
    report = parameter_recovery(truth_params, align_fit_to_truth(fit.params, perm))

    if args.output is None:
        output_path = args.fit / "recovery.toml"
    else:
        output_path = args.output

    state_section = {
        "accuracy": float(accuracy),
        "confusion": _to_serializable(confusion),
        "permutation": [int(x) for x in perm],
    }
    transition_section = {
        "alpha": float(report.transition_alpha_rmse),
        "beta": float(report.transition_beta_rmse),
    }
    emission_section = {
        "p_departure": float(report.emission_p_rmse),
        "mu_displacement": float(report.emission_mu_rmse),
        "sigma_displacement": float(report.emission_sigma_rmse),
        "lambda_comm": float(report.emission_lambda_rmse),
    }
    payload = {
        "state_recovery": state_section,
        "transition_rmse": transition_section,
        "emission_rmse": emission_section,
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    encoded = tomli_w.dumps(payload).encode("utf-8")
    output_path.write_bytes(encoded)

    print(f"recovery: {output_path}", file=sys.stderr)
    print(f"state recovery accuracy: {accuracy:.4f}", file=sys.stderr)
    return 0