src/iohmm_evac/report/cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac report`` subcommand: render diagnostic plots from a saved run."""

from __future__ import annotations

import argparse
import sys
from collections.abc import Callable
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from iohmm_evac.bootstrap.aggregate import BAND_METRICS, compute_bands
from iohmm_evac.bootstrap.shift_sweep import load_sweep_result
from iohmm_evac.report.constants import DEFAULT_DPI
from iohmm_evac.report.loader import SimulationBundle, load_bundle
from iohmm_evac.report.plots import (
    plot_bootstrap_bands,
    plot_cumulative_departures,
    plot_emission_summary,
    plot_household_trajectories,
    plot_state_occupancy,
    plot_sweep_departures,
    plot_sweep_network,
)
from iohmm_evac.report.recovery_cli import (
    FIT_SUBCOMMANDS,
    add_fit_report_subparsers,
    run_fit_report,
)
from iohmm_evac.report.summary import bundle_summary, format_summary
from iohmm_evac.sweep import SweepResult, load_sweep

__all__ = ["add_report_subparser", "run_report"]  # public entry points: parser registration + dispatch


_SWEEP_SUBCOMMANDS: frozenset[str] = frozenset({"sweep-departures", "sweep-network", "sweep-all"})
_BOOTSTRAP_SUBCOMMANDS: frozenset[str] = frozenset({"bootstrap-bands"})


_DEFAULT_HOUSEHOLD_IDS: tuple[int, ...] = (0, 1, 2)
_ALL_FILENAMES: dict[str, str] = {
    "occupancy": "occupancy.png",
    "departures": "departures.png",
    "trajectories": "trajectories.png",
    "emissions": "emissions.png",
}


def _parse_household_ids(value: str) -> list[int]:
    """Parse a comma-separated list of household IDs from the CLI."""
    if not value:
        msg = "--household-ids must not be empty"
        raise argparse.ArgumentTypeError(msg)
    out: list[int] = []
    for raw in value.split(","):
        token = raw.strip()
        if not token:
            continue
        try:
            out.append(int(token))
        except ValueError as exc:
            msg = f"Invalid household id: {token!r}"
            raise argparse.ArgumentTypeError(msg) from exc
    if not out:
        msg = "--household-ids must not be empty"
        raise argparse.ArgumentTypeError(msg)
    return out


def add_report_subparser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Attach the ``report`` command and all of its ``report <action>`` child parsers.

    Registers the single-run plot actions (``occupancy``, ``departures``,
    ``emissions``, ``trajectories``), ``summary`` and ``all`` here, then
    delegates to the fit, sweep and bootstrap registration helpers.
    The chosen child name is stored on the namespace as ``action``.
    """
    report_parser = subparsers.add_parser(
        "report", help="Render diagnostic plots from a saved run."
    )
    report_actions = report_parser.add_subparsers(dest="action", required=True)

    simple_plots = (
        ("occupancy", "Stacked area chart of state shares over time."),
        ("departures", "Cumulative departure share over time."),
        ("emissions", "Per-state emission-mean summary bars."),
    )
    for plot_name, plot_help in simple_plots:
        _add_common_io_flags(report_actions.add_parser(plot_name, help=plot_help))

    traj_parser = report_actions.add_parser(
        "trajectories", help="Per-household forecast/state/displacement."
    )
    _add_common_io_flags(traj_parser)
    ids_text = ",".join(map(str, _DEFAULT_HOUSEHOLD_IDS))
    traj_parser.add_argument(
        "--household-ids",
        type=_parse_household_ids,
        # A fresh list per parser so the two defaults never share state.
        default=list(_DEFAULT_HOUSEHOLD_IDS),
        help=f"Comma-separated household IDs (default: {ids_text}).",
    )

    summary_parser = report_actions.add_parser(
        "summary", help="Print sanity-check metrics as a table."
    )
    summary_parser.add_argument(
        "--input", type=Path, required=True, help="Observations Parquet path."
    )

    everything_parser = report_actions.add_parser(
        "all", help="Render every diagnostic plot at once."
    )
    everything_parser.add_argument(
        "--input", type=Path, required=True, help="Observations Parquet path."
    )
    everything_parser.add_argument(
        "--output-dir", type=Path, required=True, help="Directory to write PNGs into."
    )
    everything_parser.add_argument(
        "--household-ids",
        type=_parse_household_ids,
        default=list(_DEFAULT_HOUSEHOLD_IDS),
        help="Comma-separated IDs for the trajectories plot.",
    )

    add_fit_report_subparsers(report_actions)
    _add_sweep_report_subparsers(report_actions)
    _add_bootstrap_report_subparsers(report_actions)


def _add_bootstrap_report_subparsers(
    actions: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Attach the ``bootstrap-bands`` action to the ``report`` sub-parser set.

    The action reads the sweep parquet given by ``--input``, writes a PNG to
    ``--output`` and lets ``--metric`` pick one of ``BAND_METRICS``.
    """
    bands_parser = actions.add_parser(
        "bootstrap-bands",
        help="Median + quantile bands across warning shifts (Fig. 6).",
    )
    bands_parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Sweep parquet from `bootstrap shift-sweep`.",
    )
    bands_parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="PNG output path.",
    )
    bands_parser.add_argument(
        "--metric",
        choices=BAND_METRICS,
        default="failed_evacuation_count",
        help="Metric to plot (default: failed_evacuation_count).",
    )


def _add_sweep_report_subparsers(
    actions: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Register the ``sweep-departures``, ``sweep-network``, ``sweep-all`` actions."""
    sp_dep = actions.add_parser(
        "sweep-departures", help="Cross-scenario cumulative departures (Fig. 4)."
    )
    sp_dep.add_argument("--input-dir", type=Path, required=True, help="Sweep directory.")
    sp_dep.add_argument(
        "--output",
        type=Path,
        default=None,
        help="PNG output (default: <input-dir>/sweep_departures.png).",
    )

    sp_net = actions.add_parser(
        "sweep-network", help="Cross-scenario network metrics 2x2 panel (Fig. 5)."
    )
    sp_net.add_argument("--input-dir", type=Path, required=True, help="Sweep directory.")
    sp_net.add_argument(
        "--output",
        type=Path,
        default=None,
        help="PNG output (default: <input-dir>/sweep_network.png).",
    )

    sp_all = actions.add_parser(
        "sweep-all", help="Render both sweep_departures.png and sweep_network.png."
    )
    sp_all.add_argument("--input-dir", type=Path, required=True, help="Sweep directory.")
    sp_all.add_argument(
        "--output-dir", type=Path, required=True, help="Directory to write PNGs into."
    )


def _add_common_io_flags(p: argparse.ArgumentParser) -> None:
    p.add_argument("--input", type=Path, required=True, help="Observations Parquet path.")
    p.add_argument(
        "--output",
        type=Path,
        default=None,
        help="PNG output path. Defaults to <input-stem>.<plot>.png next to the input.",
    )
    p.add_argument(
        "--show",
        action="store_true",
        help="Open an interactive window via plt.show() (skip if running headless).",
    )


def _save_figure(fig: Figure, output: Path) -> None:
    """Write ``fig`` to ``output``, creating any missing parent directories first.

    The image is saved at ``DEFAULT_DPI`` with ``bbox_inches="tight"``.
    """
    target_dir = output.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    fig.savefig(output, bbox_inches="tight", dpi=DEFAULT_DPI)


def _default_output_path(input_path: Path, plot_name: str) -> Path:
    return input_path.parent / f"{input_path.stem}.{plot_name}.png"


def _render_single(
    bundle: SimulationBundle,
    plot_name: str,
    args: argparse.Namespace,
    draw: Callable[[SimulationBundle], Figure],
) -> Path | None:
    """Draw one diagnostic figure, then save it and/or show it interactively.

    Args:
        bundle: Loaded simulation run, passed straight to ``draw``.
        plot_name: Short plot identifier used in the default output filename.
        args: Parsed CLI namespace; reads ``show``, ``output`` and ``input``.
        draw: Callable that builds and returns a new figure for ``bundle``.

    Returns:
        The path the PNG was written to, or ``None`` when the figure was only
        shown (``--show`` without an explicit ``--output``).

    Without ``--show`` the figure is always saved (to ``--output`` or a default
    path next to the input). With ``--show`` it is saved only when ``--output``
    was given explicitly, instead of silently dropping that request.
    """
    fig = draw(bundle)
    try:
        output: Path | None = args.output
        if output is None and not args.show:
            output = _default_output_path(args.input, plot_name)
        if output is not None:
            # Save before showing: the figure may be torn down once an
            # interactive window is closed.
            _save_figure(fig, output)
        if args.show:
            plt.show()
        return output
    finally:
        # Always release the figure, even if saving or showing raised.
        plt.close(fig)


def _draw_occupancy(bundle: SimulationBundle) -> Figure:
    """Create a new 10x4-inch figure holding the state-occupancy plot for ``bundle``."""
    figure, axis = plt.subplots(figsize=(10, 4))
    plot_state_occupancy(bundle, ax=axis)
    return figure


def _draw_departures(bundle: SimulationBundle) -> Figure:
    """Create a new 10x4-inch figure holding the cumulative-departures plot for ``bundle``."""
    figure, axis = plt.subplots(figsize=(10, 4))
    plot_cumulative_departures(bundle, ax=axis)
    return figure


def _draw_emissions(bundle: SimulationBundle) -> Figure:
    """Create a new 8x4-inch figure holding the emission-summary plot for ``bundle``."""
    figure, axis = plt.subplots(figsize=(8, 4))
    plot_emission_summary(bundle, ax=axis)
    return figure


def _draw_trajectories(bundle: SimulationBundle, household_ids: list[int]) -> Figure:
    """Create a figure with one row per household ID, all rows sharing the x-axis.

    Each row is 2.5 inches tall and 10 inches wide. The per-row axes are handed
    to ``plot_household_trajectories`` as a list in ``household_ids`` order.
    """
    n_rows = len(household_ids)
    # squeeze=False always yields an (n_rows, 1) array, so no 1-row special case.
    figure, grid = plt.subplots(
        n_rows, 1, figsize=(10, 2.5 * n_rows), sharex=True, squeeze=False
    )
    row_axes = list(grid[:, 0])
    plot_household_trajectories(bundle, household_ids, ax=row_axes)
    return figure


def _run_all(bundle: SimulationBundle, args: argparse.Namespace) -> list[Path]:
    """Render every diagnostic plot for ``bundle`` into ``args.output_dir``.

    Args:
        bundle: Loaded simulation run to plot.
        args: Parsed CLI namespace; reads ``output_dir`` and ``household_ids``.

    Returns:
        The written PNG paths, in the order occupancy, departures, emissions,
        trajectories. Filenames come from ``_ALL_FILENAMES``.
    """
    out_dir: Path = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    household_ids = list(args.household_ids)

    def draw_trajectories(b: SimulationBundle) -> Figure:
        return _draw_trajectories(b, household_ids)

    drawers: list[tuple[str, Callable[[SimulationBundle], Figure]]] = [
        ("occupancy", _draw_occupancy),
        ("departures", _draw_departures),
        ("emissions", _draw_emissions),
        ("trajectories", draw_trajectories),
    ]
    written: list[Path] = []
    for name, draw in drawers:
        target = out_dir / _ALL_FILENAMES[name]
        fig = draw(bundle)
        try:
            _save_figure(fig, target)
        finally:
            # Release the figure even if saving fails, so no pyplot figures leak.
            plt.close(fig)
        written.append(target)
    return written


def _draw_sweep_departures(sweep: SweepResult) -> Figure:
    """Create a new 10x4-inch figure with the cross-scenario departures plot for ``sweep``."""
    figure, axis = plt.subplots(figsize=(10, 4))
    plot_sweep_departures(sweep, ax=axis)
    return figure


def _draw_sweep_network(sweep: SweepResult) -> Figure:
    """Create a 10x6-inch 2x2 grid figure with the sweep network panels for ``sweep``.

    ``tight_layout`` is applied so the four panels do not overlap.
    """
    figure, axes_grid = plt.subplots(2, 2, figsize=(10, 6))
    plot_sweep_network(sweep, ax=axes_grid)
    figure.tight_layout()
    return figure


def _run_sweep_report(args: argparse.Namespace) -> int:
    """Render the sweep figures for ``sweep-departures``, ``sweep-network`` or ``sweep-all``.

    The single-figure actions write to ``args.output``, or, when that is
    ``None``, to a fixed filename inside ``args.input_dir``. ``sweep-all``
    writes both figures into ``args.output_dir``, creating it if needed.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If ``args.action`` is not one of the sweep actions.
    """
    single_figures: dict[str, tuple[str, Callable[[SweepResult], Figure]]] = {
        "sweep-departures": ("sweep_departures.png", _draw_sweep_departures),
        "sweep-network": ("sweep_network.png", _draw_sweep_network),
    }
    action = args.action
    if action not in single_figures and action != "sweep-all":
        msg = f"Unknown sweep report action: {action!r}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    sweep = load_sweep(args.input_dir)
    targets: list[tuple[Path, Callable[[SweepResult], Figure]]]
    if action == "sweep-all":
        out_dir: Path = args.output_dir
        out_dir.mkdir(parents=True, exist_ok=True)
        targets = [(out_dir / filename, draw) for filename, draw in single_figures.values()]
    else:
        filename, draw = single_figures[action]
        out = args.output if args.output is not None else args.input_dir / filename
        targets = [(out, draw)]

    for target, draw_figure in targets:
        fig = draw_figure(sweep)
        try:
            _save_figure(fig, target)
        finally:
            # Release the figure even if saving fails, so no pyplot figures leak.
            plt.close(fig)
    return 0


def _run_bootstrap_report(args: argparse.Namespace) -> int:
    """Render the ``bootstrap-bands`` figure from a saved shift-sweep result.

    Loads ``args.input`` and computes bands at the 5/25/50/75/95th percentiles.
    It then plots ``args.metric`` and writes the PNG to ``args.output``.

    Returns:
        ``0`` on success.
    """
    sweep_result = load_sweep_result(args.input)
    band_result = compute_bands(sweep_result, percentiles=(5, 25, 50, 75, 95))
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        plot_bootstrap_bands(band_result, metric=args.metric, ax=ax)
        fig.tight_layout()
        _save_figure(fig, args.output)
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # in-process calls do not accumulate open pyplot figures.
        plt.close(fig)
    return 0


def run_report(args: argparse.Namespace) -> int:
    """Dispatch ``iohmm-evac report <action>`` after the parser has run.

    Fit, sweep and bootstrap actions are delegated to their own runners. Every
    other action first loads the simulation bundle from ``args.input`` and then
    prints the summary, renders all plots, or renders a single plot.

    Returns:
        ``0`` for locally handled actions, otherwise the delegated runner's result.

    Raises:
        ValueError: If ``args.action`` is not a known report action.
    """
    action = args.action
    if action in FIT_SUBCOMMANDS:
        return run_fit_report(args)
    if action in _SWEEP_SUBCOMMANDS:
        return _run_sweep_report(args)
    if action in _BOOTSTRAP_SUBCOMMANDS:
        return _run_bootstrap_report(args)

    bundle = load_bundle(args.input)
    if action == "all":
        _run_all(bundle, args)
        return 0
    if action == "summary":
        sys.stdout.write(format_summary(bundle_summary(bundle)) + "\n")
        return 0

    # Single-plot actions: the action name doubles as the plot name used in
    # the default output filename.
    drawers: dict[str, Callable[[SimulationBundle], Figure]] = {
        "occupancy": _draw_occupancy,
        "departures": _draw_departures,
        "emissions": _draw_emissions,
    }
    if action == "trajectories":
        selected_ids = list(args.household_ids)

        def draw_selected(b: SimulationBundle) -> Figure:
            return _draw_trajectories(b, selected_ids)

        drawers[action] = draw_selected

    drawer = drawers.get(action)
    if drawer is None:
        msg = f"Unknown report action: {action!r}"
        raise ValueError(msg)
    _render_single(bundle, action, args, drawer)
    return 0