# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac report`` subcommand: render diagnostic plots from a saved run."""
from __future__ import annotations
import argparse
import sys
from collections.abc import Callable
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from iohmm_evac.bootstrap.aggregate import BAND_METRICS, compute_bands
from iohmm_evac.bootstrap.shift_sweep import load_sweep_result
from iohmm_evac.report.constants import DEFAULT_DPI
from iohmm_evac.report.loader import SimulationBundle, load_bundle
from iohmm_evac.report.plots import (
plot_bootstrap_bands,
plot_cumulative_departures,
plot_emission_summary,
plot_household_trajectories,
plot_state_occupancy,
plot_sweep_departures,
plot_sweep_network,
)
from iohmm_evac.report.recovery_cli import (
FIT_SUBCOMMANDS,
add_fit_report_subparsers,
run_fit_report,
)
from iohmm_evac.report.summary import bundle_summary, format_summary
from iohmm_evac.sweep import SweepResult, load_sweep
__all__ = ["add_report_subparser", "run_report"]
_SWEEP_SUBCOMMANDS: frozenset[str] = frozenset({"sweep-departures", "sweep-network", "sweep-all"})
_BOOTSTRAP_SUBCOMMANDS: frozenset[str] = frozenset({"bootstrap-bands"})
_DEFAULT_HOUSEHOLD_IDS: tuple[int, ...] = (0, 1, 2)
_ALL_FILENAMES: dict[str, str] = {
"occupancy": "occupancy.png",
"departures": "departures.png",
"trajectories": "trajectories.png",
"emissions": "emissions.png",
}
def _parse_household_ids(value: str) -> list[int]:
    """Parse a comma-separated list of household IDs from the CLI.

    Whitespace around each ID is ignored and empty segments (``"1,,2"``, a
    trailing comma) are skipped.

    Raises:
        argparse.ArgumentTypeError: If a segment is not an integer, or if no
            IDs remain once empty segments are dropped.
    """
    # A falsy ``value`` yields no pieces and falls through to the "empty" error.
    pieces = [piece.strip() for piece in value.split(",")] if value else []
    ids: list[int] = []
    for token in pieces:
        if token:
            try:
                ids.append(int(token))
            except ValueError as exc:
                msg = f"Invalid household id: {token!r}"
                raise argparse.ArgumentTypeError(msg) from exc
    if ids:
        return ids
    msg = "--household-ids must not be empty"
    raise argparse.ArgumentTypeError(msg)
def add_report_subparser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
    """Register ``report`` and its child subcommands on the top-level parser.

    This function adds the single-plot actions (``occupancy``, ``departures``,
    ``emissions``, ``trajectories``), ``summary`` and ``all`` directly. It then
    delegates to the fit, sweep and bootstrap registrars for their actions.
    """
    p_report = subparsers.add_parser("report", help="Render diagnostic plots from a saved run.")
    # ``required=True``: a bare ``report`` with no action is a parse error.
    actions = p_report.add_subparsers(dest="action", required=True)
    default_ids = ",".join(str(i) for i in _DEFAULT_HOUSEHOLD_IDS)

    def _add_household_ids_flag(parser: argparse.ArgumentParser, description: str) -> None:
        # ``trajectories`` and ``all`` share this flag. Defining it once keeps the
        # parser, the default and the advertised default in sync.
        parser.add_argument(
            "--household-ids",
            type=_parse_household_ids,
            # A fresh list per parser, so the two actions never share a mutable default.
            default=list(_DEFAULT_HOUSEHOLD_IDS),
            help=f"{description} (default: {default_ids}).",
        )

    for name, helptext in [
        ("occupancy", "Stacked area chart of state shares over time."),
        ("departures", "Cumulative departure share over time."),
        ("emissions", "Per-state emission-mean summary bars."),
    ]:
        sp = actions.add_parser(name, help=helptext)
        _add_common_io_flags(sp)

    sp_traj = actions.add_parser("trajectories", help="Per-household forecast/state/displacement.")
    _add_common_io_flags(sp_traj)
    _add_household_ids_flag(sp_traj, "Comma-separated household IDs")

    sp_summary = actions.add_parser("summary", help="Print sanity-check metrics as a table.")
    sp_summary.add_argument("--input", type=Path, required=True, help="Observations Parquet path.")

    sp_all = actions.add_parser("all", help="Render every diagnostic plot at once.")
    sp_all.add_argument("--input", type=Path, required=True, help="Observations Parquet path.")
    sp_all.add_argument(
        "--output-dir", type=Path, required=True, help="Directory to write PNGs into."
    )
    _add_household_ids_flag(sp_all, "Comma-separated IDs for the trajectories plot")

    add_fit_report_subparsers(actions)
    _add_sweep_report_subparsers(actions)
    _add_bootstrap_report_subparsers(actions)
def _add_bootstrap_report_subparsers(
    actions: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Register the ``bootstrap-bands`` action (Fig. 6).

    The action takes the sweep Parquet written by ``bootstrap shift-sweep`` and a
    required PNG path. ``--metric`` is limited to ``BAND_METRICS``.
    """
    bands = actions.add_parser(
        "bootstrap-bands",
        help="Median + quantile bands across warning shifts (Fig. 6).",
    )
    bands.add_argument(
        "--input",
        required=True,
        type=Path,
        help="Sweep parquet from `bootstrap shift-sweep`.",
    )
    bands.add_argument("--output", required=True, type=Path, help="PNG output path.")
    bands.add_argument(
        "--metric",
        default="failed_evacuation_count",
        choices=BAND_METRICS,
        help="Metric to plot (default: failed_evacuation_count).",
    )
def _add_sweep_report_subparsers(
    actions: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Register the ``sweep-departures``, ``sweep-network``, ``sweep-all`` actions.

    The two single-figure actions take ``--input-dir`` and an optional
    ``--output``. ``sweep-all`` instead requires an ``--output-dir`` that both
    PNGs are written into.
    """
    # (action name, action help, --output help) for the single-figure actions.
    single_figure_actions = (
        (
            "sweep-departures",
            "Cross-scenario cumulative departures (Fig. 4).",
            "PNG output (default: <input-dir>/sweep_departures.png).",
        ),
        (
            "sweep-network",
            "Cross-scenario network metrics 2x2 panel (Fig. 5).",
            "PNG output (default: <input-dir>/sweep_network.png).",
        ),
    )
    for action_name, action_help, output_help in single_figure_actions:
        parser = actions.add_parser(action_name, help=action_help)
        parser.add_argument("--input-dir", type=Path, required=True, help="Sweep directory.")
        parser.add_argument("--output", type=Path, default=None, help=output_help)

    combined = actions.add_parser(
        "sweep-all", help="Render both sweep_departures.png and sweep_network.png."
    )
    combined.add_argument("--input-dir", type=Path, required=True, help="Sweep directory.")
    combined.add_argument(
        "--output-dir", type=Path, required=True, help="Directory to write PNGs into."
    )
def _add_common_io_flags(p: argparse.ArgumentParser) -> None:
    """Attach the ``--input`` / ``--output`` / ``--show`` flags used by single-plot actions.

    ``--input`` is required. ``--output`` defaults to ``None`` so the caller can
    derive a path next to the input. ``--show`` is a boolean switch.
    """
    add = p.add_argument
    add("--input", required=True, type=Path, help="Observations Parquet path.")
    add(
        "--output",
        default=None,
        type=Path,
        help="PNG output path. Defaults to <input-stem>.<plot>.png next to the input.",
    )
    add(
        "--show",
        action="store_true",
        help="Open an interactive window via plt.show() (skip if running headless).",
    )
def _save_figure(fig: Figure, output: Path) -> None:
    """Write *fig* to *output* at ``DEFAULT_DPI`` with a tight bounding box.

    Any missing parent directories of *output* are created first.
    """
    destination_dir = output.parent
    destination_dir.mkdir(exist_ok=True, parents=True)
    fig.savefig(output, bbox_inches="tight", dpi=DEFAULT_DPI)
def _default_output_path(input_path: Path, plot_name: str) -> Path:
    """Return ``<input-stem>.<plot_name>.png`` in the same directory as *input_path*."""
    filename = f"{input_path.stem}.{plot_name}.png"
    return input_path.parent / filename
def _render_single(
    bundle: SimulationBundle,
    plot_name: str,
    args: argparse.Namespace,
    draw: Callable[[SimulationBundle], Figure],
) -> Path | None:
    """Draw one diagnostic figure, then save and/or display it.

    Saving rules:
      * without ``--show``: always save, to ``--output`` or, if that is absent,
        to ``<input-stem>.<plot_name>.png`` next to ``--input``;
      * with ``--show``: save only if ``--output`` was given explicitly, then
        display the figure via ``plt.show()``.

    The figure is always closed before returning, including when saving raises.

    Returns:
        The path written, or ``None`` when nothing was saved (``--show`` without
        ``--output``).
    """
    fig = draw(bundle)
    try:
        output: Path | None = args.output
        if output is None and not args.show:
            output = _default_output_path(args.input, plot_name)
        if output is not None:
            # Save before ``plt.show()`` so the PNG is written however the
            # interactive session ends.
            _save_figure(fig, output)
        if args.show:
            plt.show()
        return output
    finally:
        # Release the figure from pyplot's global registry even on failure.
        plt.close(fig)
def _draw_occupancy(bundle: SimulationBundle) -> Figure:
    """Return a new 10x4-inch figure with the state-occupancy chart for *bundle*."""
    figure, axis = plt.subplots(figsize=(10, 4))
    plot_state_occupancy(bundle, ax=axis)
    return figure
def _draw_departures(bundle: SimulationBundle) -> Figure:
    """Return a new 10x4-inch figure with the cumulative-departures curve for *bundle*."""
    figure, axis = plt.subplots(figsize=(10, 4))
    plot_cumulative_departures(bundle, ax=axis)
    return figure
def _draw_emissions(bundle: SimulationBundle) -> Figure:
    """Return a new 8x4-inch figure with the emission summary for *bundle*."""
    figure, axis = plt.subplots(figsize=(8, 4))
    plot_emission_summary(bundle, ax=axis)
    return figure
def _draw_trajectories(bundle: SimulationBundle, household_ids: list[int]) -> Figure:
    """Return a figure with one stacked trajectory panel per household, sharing the x-axis.

    The figure is 10 inches wide and 2.5 inches tall per requested household.
    """
    n_rows = len(household_ids)
    fig, grid = plt.subplots(
        n_rows, 1, figsize=(10, 2.5 * n_rows), sharex=True, squeeze=False
    )
    # ``squeeze=False`` always yields a 2-D (rows, 1) array. Taking column 0
    # therefore gives one Axes per household, with no special case for a
    # single household.
    panels = list(grid[:, 0])
    plot_household_trajectories(bundle, household_ids, ax=panels)
    return fig
def _run_all(bundle: SimulationBundle, args: argparse.Namespace) -> list[Path]:
    """Render every single-run diagnostic PNG into ``args.output_dir``.

    Filenames come from ``_ALL_FILENAMES``. The trajectories plot uses
    ``args.household_ids``. Each figure is closed after saving, including
    when the save raises.

    Returns:
        The written paths, in render order: occupancy, departures, emissions,
        trajectories.
    """
    out_dir: Path = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    renderers: list[tuple[str, Callable[[SimulationBundle], Figure]]] = [
        ("occupancy", _draw_occupancy),
        ("departures", _draw_departures),
        ("emissions", _draw_emissions),
        ("trajectories", lambda b: _draw_trajectories(b, list(args.household_ids))),
    ]
    written: list[Path] = []
    for name, draw in renderers:
        fig = draw(bundle)
        target = out_dir / _ALL_FILENAMES[name]
        try:
            _save_figure(fig, target)
        finally:
            # Release the figure from pyplot's global registry even on failure.
            plt.close(fig)
        written.append(target)
    return written
def _draw_sweep_departures(sweep: SweepResult) -> Figure:
    """Return a new 10x4-inch figure with the cross-scenario departures plot for *sweep*."""
    figure, axis = plt.subplots(figsize=(10, 4))
    plot_sweep_departures(sweep, ax=axis)
    return figure
def _draw_sweep_network(sweep: SweepResult) -> Figure:
    """Return a new 10x6-inch figure with the 2x2 grid of sweep network panels.

    ``tight_layout`` runs after plotting so the panel labels do not overlap.
    """
    figure, panel_grid = plt.subplots(2, 2, figsize=(10, 6))
    plot_sweep_network(sweep, ax=panel_grid)
    figure.tight_layout()
    return figure
def _run_sweep_report(args: argparse.Namespace) -> int:
    """Render cross-scenario sweep figures for the ``sweep-*`` report actions.

    ``sweep-departures`` and ``sweep-network`` each write one PNG, to
    ``--output`` if given, otherwise to a fixed filename inside ``--input-dir``.
    ``sweep-all`` writes both PNGs into ``--output-dir``.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If ``args.action`` is not a sweep action. The sweep is not
            loaded in that case.
    """
    action = args.action
    input_dir: Path = args.input_dir

    def _single_target(filename: str) -> Path:
        # Only the single-figure actions define ``--output``. An explicit value
        # wins; otherwise write next to the sweep data.
        explicit: Path | None = args.output
        return explicit if explicit is not None else input_dir / filename

    jobs: list[tuple[Path, Callable[[SweepResult], Figure]]]
    if action == "sweep-departures":
        jobs = [(_single_target("sweep_departures.png"), _draw_sweep_departures)]
    elif action == "sweep-network":
        jobs = [(_single_target("sweep_network.png"), _draw_sweep_network)]
    elif action == "sweep-all":
        out_dir: Path = args.output_dir
        jobs = [
            (out_dir / "sweep_departures.png", _draw_sweep_departures),
            (out_dir / "sweep_network.png", _draw_sweep_network),
        ]
    else:  # pragma: no cover
        msg = f"Unknown sweep report action: {action!r}"
        raise ValueError(msg)

    sweep = load_sweep(input_dir)
    for target, draw in jobs:
        fig = draw(sweep)
        try:
            # ``_save_figure`` creates ``target.parent`` (e.g. ``--output-dir``) if missing.
            _save_figure(fig, target)
        finally:
            # Release the figure from pyplot's global registry even on failure.
            plt.close(fig)
    return 0
def _run_bootstrap_report(args: argparse.Namespace) -> int:
    """Render the ``bootstrap-bands`` figure from a shift-sweep result.

    The function loads ``args.input``, computes 5/25/50/75/95 percentile bands,
    plots ``args.metric`` and writes the PNG to ``args.output``. The figure is
    closed even if plotting or saving raises.

    Returns:
        ``0`` on success.
    """
    sweep_result = load_sweep_result(args.input)
    band_result = compute_bands(sweep_result, percentiles=(5, 25, 50, 75, 95))
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        plot_bootstrap_bands(band_result, metric=args.metric, ax=ax)
        fig.tight_layout()
        _save_figure(fig, args.output)
    finally:
        # Release the figure from pyplot's global registry even on failure.
        plt.close(fig)
    return 0
def run_report(args: argparse.Namespace) -> int:
    """Dispatch ``iohmm-evac report <action>`` after the parser has run.

    Fit, sweep and bootstrap actions are forwarded to their own runners, and
    their return value is passed through. Every other action first loads the
    observations bundle from ``args.input`` and then renders or prints from it.

    Returns:
        The delegated runner's exit code, or ``0`` for bundle-based actions.

    Raises:
        ValueError: If ``args.action`` matches no known report action.
    """
    action = args.action

    # Actions that do not operate on a single-run observations bundle.
    if action in FIT_SUBCOMMANDS:
        return run_fit_report(args)
    if action in _SWEEP_SUBCOMMANDS:
        return _run_sweep_report(args)
    if action in _BOOTSTRAP_SUBCOMMANDS:
        return _run_bootstrap_report(args)

    bundle = load_bundle(args.input)
    # Plots whose drawer needs nothing but the bundle; the action name is also
    # the plot name used for the default output filename.
    bundle_only_plots: dict[str, Callable[[SimulationBundle], Figure]] = {
        "occupancy": _draw_occupancy,
        "departures": _draw_departures,
        "emissions": _draw_emissions,
    }
    if action in bundle_only_plots:
        _render_single(bundle, action, args, bundle_only_plots[action])
    elif action == "trajectories":
        household_ids = list(args.household_ids)
        _render_single(
            bundle,
            "trajectories",
            args,
            lambda b: _draw_trajectories(b, household_ids),
        )
    elif action == "summary":
        sys.stdout.write(format_summary(bundle_summary(bundle)) + "\n")
    elif action == "all":
        _run_all(bundle, args)
    else:
        msg = f"Unknown report action: {action!r}"
        raise ValueError(msg)
    return 0