# src/iohmm_evac/report/recovery_cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Recovery-flavored ``iohmm-evac report`` subcommands and dispatch.

Wired by :mod:`iohmm_evac.report.cli`; lives in its own module to keep
``report/cli.py`` under the project's 300-line per-file ceiling.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from iohmm_evac.diagnostics.alignment import align_states, apply_permutation
from iohmm_evac.diagnostics.recovery import (
    align_fit_to_truth,
    parameter_recovery,
    state_recovery_accuracy,
    state_recovery_confusion,
)
from iohmm_evac.inference.cli import _build_dgp_truth_from_config
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.fit_params import K, dgp_truth_to_fit_init
from iohmm_evac.inference.io import read_fit_bundle
from iohmm_evac.report.constants import DEFAULT_DPI
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.report.recovery_plots import (
    plot_log_likelihood_trace,
    plot_parameter_recovery,
    plot_state_recovery_confusion,
)

# Public names of this module; the module docstring says they are wired in by
# ``iohmm_evac.report.cli``.
__all__ = [
    "FIT_SUBCOMMANDS",
    "add_fit_report_subparsers",
    "run_fit_report",
]


# Names of the fit-flavored ``report`` subcommands. Each entry is registered
# by ``add_fit_report_subparsers`` and dispatched by ``run_fit_report``, so a
# new subcommand has to be added in all three places.
FIT_SUBCOMMANDS: tuple[str, ...] = (
    "fit-summary",
    "recovery-confusion",
    "parameter-recovery",
    "ll-trace",
)


def add_fit_report_subparsers(
    actions: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Attach the four fit-related ``report`` subcommands to ``actions``.

    Every subcommand requires ``--fit``. The two recovery subcommands also
    require ``--truth``, and each plotting subcommand accepts an optional
    ``--output`` PNG path (``None`` when omitted).

    Args:
        actions: Sub-parser action that receives the new sub-parsers.
    """
    # One row per subcommand: (name, help text, takes --truth, takes --output).
    # Row order is the registration order.
    layout = (
        (
            "fit-summary",
            "Print a fit-bundle summary (LLs, iterations, convergence).",
            False,
            False,
        ),
        (
            "recovery-confusion",
            "Confusion-matrix heatmap (true x fit-aligned).",
            True,
            True,
        ),
        (
            "parameter-recovery",
            "Scatter of true vs estimated parameter values.",
            True,
            True,
        ),
        (
            "ll-trace",
            "Log-likelihood trace per restart over EM iterations.",
            False,
            True,
        ),
    )
    for name, blurb, wants_truth, wants_output in layout:
        sub = actions.add_parser(name, help=blurb)
        sub.add_argument("--fit", type=Path, required=True, help="Fit bundle directory.")
        if wants_truth:
            sub.add_argument(
                "--truth", type=Path, required=True, help="Truth simulation bundle Parquet path."
            )
        if wants_output:
            sub.add_argument("--output", type=Path, default=None, help="PNG output path.")


def _save_or_show(fig: Figure, output: Path | None, default_path: Path) -> Path:
    """Write ``fig`` to disk, close it, and return the path that was written.

    Despite the name, the figure is never displayed interactively: it is
    always saved, to ``output`` when one is given and to ``default_path``
    otherwise. Missing parent directories of the target are created.

    Args:
        fig: Figure to save; it is closed before this function returns or raises.
        output: Explicit destination, or ``None`` to fall back to ``default_path``.
        default_path: Destination used when ``output`` is ``None``.

    Returns:
        The path the figure was written to.
    """
    target = output if output is not None else default_path
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(target, dpi=DEFAULT_DPI, bbox_inches="tight")
    finally:
        # pyplot keeps figures it created alive until they are closed, so
        # release this one even when creating the directory or writing fails.
        plt.close(fig)
    return target


def _print_fit_summary(args: argparse.Namespace) -> int:
    """Print a plain-text summary of the fit bundle at ``args.fit`` to stdout.

    The fixed header lines are always printed. The per-restart lines (final
    log-likelihoods, iteration counts, convergence flags) are printed only
    when the corresponding bundle field is non-empty.

    Args:
        args: Parsed CLI arguments; only ``args.fit`` is read.

    Returns:
        Always ``0``.
    """
    summary = read_fit_bundle(args.fit)

    def emit(line: str) -> None:
        print(line, file=sys.stdout)

    emit(f"fit dir:           {args.fit}")
    emit(f"K (states):        {K}")
    emit(f"feature_names:     {list(summary.params.feature_names)}")
    emit(f"restarts:          {len(summary.log_likelihood_traces)}")
    emit(f"best restart:      #{summary.best_index}")

    final_lls = summary.final_log_likelihoods
    if final_lls:
        emit("final LLs:         " + ", ".join(f"{ll:.3f}" for ll in final_lls))

    iteration_counts = summary.iterations_per_restart
    if iteration_counts:
        emit("iterations:        " + ", ".join(str(count) for count in iteration_counts))

    converged_flags = summary.converged_per_restart
    if converged_flags:
        emit("converged:         " + ", ".join(str(flag) for flag in converged_flags))

    return 0


def _draw_confusion(args: argparse.Namespace) -> int:
    """Save the state-recovery confusion heatmap for a fit against its truth.

    Args:
        args: Parsed CLI arguments; reads ``args.truth``, ``args.fit`` and
            ``args.output``. With no ``--output`` the image is written to
            ``recovery_confusion.png`` inside the fit directory.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If the truth bundle carries no true state path.
    """
    truth_bundle = load_bundle(args.truth)
    fitted = read_fit_bundle(args.fit)
    true_states = bundle_to_fit_data(truth_bundle).true_states
    if true_states is None:
        msg = "Truth bundle has no state path; cannot draw confusion."
        raise ValueError(msg)

    # Apply the permutation from ``align_states`` to the posterior states
    # before scoring them against the truth.
    permutation = align_states(true_states, fitted.posterior_states, K)
    relabeled = apply_permutation(fitted.posterior_states, permutation)
    matrix = state_recovery_confusion(true_states, relabeled)
    score = state_recovery_accuracy(true_states, relabeled)

    figure, axis = plt.subplots(figsize=(5, 5))
    plot_state_recovery_confusion(matrix, ax=axis)
    axis.set_title(f"State recovery (accuracy={score:.3f})")
    _save_or_show(figure, args.output, args.fit / "recovery_confusion.png")
    return 0


def _draw_parameter_recovery(args: argparse.Namespace) -> int:
    """Save the parameter-recovery plot (truth vs. aligned fit parameters).

    Args:
        args: Parsed CLI arguments; reads ``args.truth``, ``args.fit`` and
            ``args.output``. With no ``--output`` the image is written to
            ``parameter_recovery.png`` inside the fit directory.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If the truth bundle carries no true state path.
    """
    truth_bundle = load_bundle(args.truth)
    fitted = read_fit_bundle(args.fit)
    true_states = bundle_to_fit_data(truth_bundle).true_states
    if true_states is None:
        msg = "Truth bundle has no state path; cannot draw recovery."
        raise ValueError(msg)
    permutation = align_states(true_states, fitted.posterior_states, K)

    # Reference parameters are rebuilt from the configuration stored in the
    # truth bundle; the fitted ones are aligned with the same permutation.
    dgp = _build_dgp_truth_from_config(truth_bundle.config)
    reference = dgp_truth_to_fit_init(dgp.transitions, dgp.emissions, dgp.population)
    estimate = align_fit_to_truth(fitted.params, permutation)
    recovery = parameter_recovery(reference, estimate)

    figure, axis = plt.subplots(figsize=(7, 7))
    plot_parameter_recovery(recovery, ax=axis)
    _save_or_show(figure, args.output, args.fit / "parameter_recovery.png")
    return 0


def _draw_ll_trace(args: argparse.Namespace) -> int:
    """Save the two-panel EM log-likelihood trace figure for a fit bundle.

    Args:
        args: Parsed CLI arguments; reads ``args.fit`` and ``args.output``.
            With no ``--output`` the image is written to ``ll_trace.png``
            inside the fit directory.

    Returns:
        Always ``0``.
    """
    fitted = read_fit_bundle(args.fit)
    figure, panels = plt.subplots(1, 2, figsize=(12, 4))
    figure.suptitle("EM log-likelihood traces (units: total LL across all (i,t))", fontsize=10)
    # The plotter is given plain lists: one list per trace, and the two axes
    # as a list.
    per_restart = list(map(list, fitted.log_likelihood_traces))
    plot_log_likelihood_trace(per_restart, ax=list(panels), best_index=fitted.best_index)
    _save_or_show(figure, args.output, args.fit / "ll_trace.png")
    return 0


def run_fit_report(args: argparse.Namespace) -> int:
    """Run the fit-flavored ``report`` subcommand selected by ``args.action``.

    Args:
        args: Parsed CLI arguments; ``args.action`` names the subcommand and
            the remaining attributes are forwarded to its handler.

    Returns:
        The exit code returned by the selected handler.

    Raises:
        ValueError: If ``args.action`` is not one of the four known actions.
    """
    requested = args.action
    if requested == "fit-summary":
        handler = _print_fit_summary
    elif requested == "recovery-confusion":
        handler = _draw_confusion
    elif requested == "parameter-recovery":
        handler = _draw_parameter_recovery
    elif requested == "ll-trace":
        handler = _draw_ll_trace
    else:
        msg = f"Unknown fit-report action: {requested!r}"
        raise ValueError(msg)
    return handler(args)