# src/iohmm_evac/bootstrap_cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac bootstrap`` subcommands: fit, shift-sweep, summary."""

from __future__ import annotations

import argparse
import sys
import time
from pathlib import Path

import numpy as np

from iohmm_evac.bootstrap.aggregate import BAND_METRICS, compute_bands
from iohmm_evac.bootstrap.runner import (
    load_bootstrap_fits,
    run_bootstrap_fits,
)
from iohmm_evac.bootstrap.shift_sweep import (
    ShiftSweepResult,
    load_sweep_result,
    run_shift_sweep,
    write_sweep_result,
)
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.io import read_fit_bundle
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.scenarios import build_scenario

# Names exported by ``from iohmm_evac.bootstrap_cli import *``; the ``_run_*``
# handlers and parsing helpers below are intentionally left out.
__all__ = [
    "add_bootstrap_subparser",
    "format_bootstrap_summary",
    "run_bootstrap_command",
]


# Default for ``bootstrap shift-sweep --shifts``: warning-shift values in hours.
_DEFAULT_SHIFTS: tuple[int, ...] = (-24, -16, -8, 0, 8, 16, 24)


def _parse_shifts(value: str) -> tuple[int, ...]:
    """Parse a comma-separated ``--shifts`` value into a tuple of integers.

    Whitespace around each item is ignored and empty items (for example the
    one produced by a trailing comma) are skipped.

    Args:
        value: Raw command-line string such as ``"-8,0,8"``.

    Returns:
        The parsed integers, in the order given.

    Raises:
        argparse.ArgumentTypeError: If an item is not an integer, or if no
            items remain after skipping the empty ones.
    """
    items: list[int] = []
    for token in (part.strip() for part in value.split(",")):
        if not token:
            continue
        try:
            items.append(int(token))
        except ValueError:
            # A bare ValueError makes argparse print the generic
            # "invalid _parse_shifts value" text; name the bad token instead.
            msg = f"--shifts values must be integers, got {token!r}"
            raise argparse.ArgumentTypeError(msg) from None
    if not items:
        msg = "--shifts must list at least one integer"
        raise argparse.ArgumentTypeError(msg)
    return tuple(items)


def add_bootstrap_subparser(
    subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Attach the ``bootstrap`` command and its three actions to *subparsers*.

    The selected action (``fit``, ``shift-sweep`` or ``summary``) is stored on
    the parsed namespace under ``action`` and must be supplied.
    """
    bootstrap_parser = subparsers.add_parser(
        "bootstrap",
        help="Parametric bootstrap and warning-shift sweep (Build 4).",
    )
    action_parsers = bootstrap_parser.add_subparsers(dest="action", required=True)

    # ---- bootstrap fit -------------------------------------------------
    fit_parser = action_parsers.add_parser(
        "fit",
        help="Refit the IO-HMM on B household resamples.",
    )
    fit_parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Baseline observations Parquet.",
    )
    fit_parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Directory to write per-replicate fits into.",
    )
    fit_parser.add_argument(
        "--n-replicates",
        type=int,
        default=50,
        help="Number of bootstrap replicates (default 50).",
    )
    fit_parser.add_argument(
        "--jobs",
        type=int,
        default=-1,
        help="joblib n_jobs (default -1, all cores).",
    )
    fit_parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Base RNG seed (default 0).",
    )
    fit_parser.add_argument(
        "--warm-start",
        type=Path,
        default=None,
        help="Optional fit-bundle directory whose theta.toml seeds each replicate's EM.",
    )
    fit_parser.add_argument(
        "--max-iter",
        type=int,
        default=200,
        help="Per-replicate EM iteration cap.",
    )
    fit_parser.add_argument(
        "--tol",
        type=float,
        default=1e-5,
        help="EM relative-LL tolerance.",
    )
    fit_parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress per-replicate progress.",
    )

    # ---- bootstrap shift-sweep -----------------------------------------
    sweep_parser = action_parsers.add_parser(
        "shift-sweep",
        help="Run the (replicate × shift) sweep using saved fits.",
    )
    sweep_parser.add_argument(
        "--bootstrap-dir",
        type=Path,
        required=True,
        help="Directory of replicate_NNN/ subdirectories produced by `bootstrap fit`.",
    )
    sweep_parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Long-format sweep parquet to write.",
    )
    sweep_parser.add_argument(
        "--shifts",
        type=_parse_shifts,
        default=_DEFAULT_SHIFTS,
        help="Comma-separated warning-shift values in hours (default: -24,-16,-8,0,8,16,24).",
    )
    sweep_parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Base seed (default 0).",
    )
    sweep_parser.add_argument(
        "--scenario",
        default="baseline",
        help="Scenario whose timeline gets shifted (default: baseline).",
    )
    sweep_parser.add_argument(
        "--n-households",
        type=int,
        default=10_000,
        help="Households per simulation (default 10000).",
    )
    sweep_parser.add_argument(
        "--n-hours",
        type=int,
        default=120,
        help="Simulation horizon in hours (default 120).",
    )
    sweep_parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress per-cell progress.",
    )

    # ---- bootstrap summary ---------------------------------------------
    summary_parser = action_parsers.add_parser(
        "summary",
        help="Print a (shift, metric) → median ± [P25,P75] / [P5,P95] table.",
    )
    summary_parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Sweep parquet from `bootstrap shift-sweep`.",
    )


def _run_fit(args: argparse.Namespace) -> int:
    """Handle ``bootstrap fit``: refit on resamples and write the replicates.

    Loads the observations from ``args.input``, optionally reads warm-start
    parameters from the ``args.warm_start`` fit bundle, and hands everything to
    ``run_bootstrap_fits``. Unless ``args.quiet`` is set, a banner and a
    closing summary (elapsed time, mean iterations, mean log-likelihood) are
    printed to stderr.

    Returns:
        Always ``0``.
    """
    data = bundle_to_fit_data(load_bundle(args.input))
    em_config = EMConfig(max_iter=args.max_iter, tol=args.tol, verbose=False)

    warm: object | None = (
        None if args.warm_start is None else read_fit_bundle(Path(args.warm_start)).params
    )

    if not args.quiet:
        init_kind = "cold" if warm is None else "warm"
        banner = (
            f"Bootstrap fit: N={data.n}, replicates={args.n_replicates}, "
            f"jobs={args.jobs}, init={init_kind}"
        )
        print(banner, file=sys.stderr)

    t0 = time.perf_counter()
    fits = run_bootstrap_fits(
        data=data,
        n_replicates=args.n_replicates,
        em_config=em_config,
        base_seed=args.seed,
        n_jobs=args.jobs,
        output_dir=args.output_dir,
        warm_start_theta=warm,  # type: ignore[arg-type]
    )
    elapsed = time.perf_counter() - t0

    if args.quiet:
        return 0

    # np.mean of an empty list would be NaN (with a warning); report 0.0 instead.
    if fits:
        mean_iters = float(np.mean([fit.iterations for fit in fits]))
        mean_ll = float(np.mean([fit.final_log_likelihood for fit in fits]))
    else:
        mean_iters = 0.0
        mean_ll = 0.0
    report = (
        f"Wrote {len(fits)} replicates to {args.output_dir} "
        f"in {elapsed:.1f}s (avg iters={mean_iters:.1f}, avg LL={mean_ll:.2f})"
    )
    print(report, file=sys.stderr)
    return 0


def _run_shift_sweep(args: argparse.Namespace) -> int:
    """Handle ``bootstrap shift-sweep``: run the sweep and write its result.

    Loads the saved fits from ``args.bootstrap_dir``, builds the scenario named
    by ``args.scenario``, runs ``run_shift_sweep`` over ``args.shifts`` and
    writes the result to ``args.output``. The reported elapsed time covers both
    the sweep and the write. Progress goes to stderr unless ``args.quiet``.

    Returns:
        Always ``0``.
    """
    replicate_fits = load_bootstrap_fits(args.bootstrap_dir)
    base_scenario = build_scenario(args.scenario)

    if not args.quiet:
        banner = (
            f"Shift sweep: {len(replicate_fits)} replicates × {len(args.shifts)} shifts, "
            f"scenario={args.scenario}"
        )
        print(banner, file=sys.stderr)

    t0 = time.perf_counter()
    sweep = run_shift_sweep(
        bootstrap_fits=replicate_fits,
        shifts=tuple(args.shifts),
        scenario_base=base_scenario,
        n_households=args.n_households,
        n_hours=args.n_hours,
        base_seed=args.seed,
    )
    write_sweep_result(sweep, args.output)
    elapsed = time.perf_counter() - t0

    if args.quiet:
        return 0
    print(
        f"Wrote {len(sweep.rows)} rows to {args.output} in {elapsed:.1f}s",
        file=sys.stderr,
    )
    return 0


def format_bootstrap_summary(result: ShiftSweepResult) -> str:
    """Build a plain-text table with one row per shift and one column per metric.

    Each metric cell reads ``P50 [P25,P75] (P5,P95)`` with two decimals. The
    first row is a header (``shift`` followed by ``BAND_METRICS``). Columns are
    left-aligned and padded to the widest cell plus three spaces; trailing
    whitespace is stripped from every line.
    """
    bands = compute_bands(result, percentiles=(5, 25, 50, 75, 95))

    titles = ("shift", *BAND_METRICS)
    table: list[tuple[str, ...]] = [titles]
    for idx, shift in enumerate(bands.shifts):
        cells = [f"{shift:+d}"]
        for metric in BAND_METRICS:
            # Same lookup order as the percentile tuple: P5, P25, P50, P75, P95.
            q05, q25, q50, q75, q95 = (
                bands.quantile(metric, pct)[idx] for pct in (5, 25, 50, 75, 95)
            )
            cells.append(f"{q50:.2f} [{q25:.2f},{q75:.2f}] ({q05:.2f},{q95:.2f})")
        table.append(tuple(cells))

    gutter = 3
    col_widths = [
        max(len(record[col]) for record in table) for col in range(len(titles))
    ]
    rendered = [
        "".join(
            text.ljust(col_width + gutter)
            for text, col_width in zip(record, col_widths, strict=True)
        ).rstrip()
        for record in table
    ]
    return "\n".join(rendered)


def _run_summary(args: argparse.Namespace) -> int:
    """Handle ``bootstrap summary``: write the sweep table to stdout.

    Reads the sweep result from ``args.input`` and emits the formatted table
    followed by a single newline.

    Returns:
        Always ``0``.
    """
    sweep = load_sweep_result(args.input)
    table = format_bootstrap_summary(sweep)
    sys.stdout.write(table + "\n")
    return 0


def run_bootstrap_command(args: argparse.Namespace) -> int:
    """Dispatch ``iohmm-evac bootstrap <action>``."""
    if args.action == "fit":
        return _run_fit(args)
    if args.action == "shift-sweep":
        return _run_shift_sweep(args)
    if args.action == "summary":
        return _run_summary(args)
    msg = f"Unknown bootstrap action: {args.action!r}"  # pragma: no cover
    raise ValueError(msg)  # pragma: no cover