# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""``iohmm-evac bootstrap`` subcommands: fit, shift-sweep, summary."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
import numpy as np
from iohmm_evac.bootstrap.aggregate import BAND_METRICS, compute_bands
from iohmm_evac.bootstrap.runner import (
load_bootstrap_fits,
run_bootstrap_fits,
)
from iohmm_evac.bootstrap.shift_sweep import (
ShiftSweepResult,
load_sweep_result,
run_shift_sweep,
write_sweep_result,
)
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.io import read_fit_bundle
from iohmm_evac.report.loader import load_bundle
from iohmm_evac.scenarios import build_scenario
# Public surface exported by ``from ... import *``: the CLI registration hook,
# the summary-table renderer, and the ``bootstrap`` action dispatcher.
__all__ = ["add_bootstrap_subparser", "format_bootstrap_summary", "run_bootstrap_command"]
# Default ``--shifts`` for ``bootstrap shift-sweep``: every 8 h from -24 to +24.
_DEFAULT_SHIFTS: tuple[int, ...] = tuple(range(-24, 25, 8))
def _parse_shifts(value: str) -> tuple[int, ...]:
    """Parse a comma-separated list of integer warning shifts.

    Used as the argparse ``type=`` callable for ``--shifts``. Whitespace around
    each item is ignored, and empty items (e.g. from a trailing comma) are
    skipped. Order is preserved as given.

    Args:
        value: Raw command-line string such as ``"-24, -8,0,8"``.

    Returns:
        The parsed shifts as a tuple of ints.

    Raises:
        argparse.ArgumentTypeError: If an item is not an integer, or if no
            items remain after skipping blanks.
    """
    items: list[int] = []
    for token in (part.strip() for part in value.split(",")):
        if not token:
            continue
        try:
            items.append(int(token))
        except ValueError as err:
            # A bare ValueError would make argparse print the opaque
            # "invalid _parse_shifts value"; name the offending token instead.
            msg = f"--shifts: {token!r} is not an integer"
            raise argparse.ArgumentTypeError(msg) from err
    if not items:
        msg = "--shifts must list at least one integer"
        raise argparse.ArgumentTypeError(msg)
    return tuple(items)
def add_bootstrap_subparser(
    subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Register the ``bootstrap`` command and its child actions.

    Adds ``fit``, ``shift-sweep`` and ``summary`` beneath ``bootstrap``. The
    selected action name is stored on the parsed namespace as ``action`` and
    is mandatory.
    """
    bootstrap = subparsers.add_parser(
        "bootstrap", help="Parametric bootstrap and warning-shift sweep (Build 4)."
    )
    children = bootstrap.add_subparsers(dest="action", required=True)
    _register_fit_action(children)
    _register_shift_sweep_action(children)
    _register_summary_action(children)


def _register_fit_action(
    children: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Add the ``bootstrap fit`` action and its options."""
    fit = children.add_parser("fit", help="Refit the IO-HMM on B household resamples.")
    fit.add_argument("--input", type=Path, required=True, help="Baseline observations Parquet.")
    fit.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Directory to write per-replicate fits into.",
    )
    fit.add_argument(
        "--n-replicates",
        type=int,
        default=50,
        help="Number of bootstrap replicates (default 50).",
    )
    fit.add_argument("--jobs", type=int, default=-1, help="joblib n_jobs (default -1, all cores).")
    fit.add_argument("--seed", type=int, default=0, help="Base RNG seed (default 0).")
    fit.add_argument(
        "--warm-start",
        type=Path,
        default=None,
        help="Optional fit-bundle directory whose theta.toml seeds each replicate's EM.",
    )
    fit.add_argument("--max-iter", type=int, default=200, help="Per-replicate EM iteration cap.")
    fit.add_argument("--tol", type=float, default=1e-5, help="EM relative-LL tolerance.")
    fit.add_argument("--quiet", action="store_true", help="Suppress per-replicate progress.")


def _register_shift_sweep_action(
    children: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Add the ``bootstrap shift-sweep`` action and its options."""
    sweep = children.add_parser(
        "shift-sweep", help="Run the (replicate × shift) sweep using saved fits."
    )
    sweep.add_argument(
        "--bootstrap-dir",
        type=Path,
        required=True,
        help="Directory of replicate_NNN/ subdirectories produced by `bootstrap fit`.",
    )
    sweep.add_argument(
        "--output", type=Path, required=True, help="Long-format sweep parquet to write."
    )
    sweep.add_argument(
        "--shifts",
        type=_parse_shifts,
        default=_DEFAULT_SHIFTS,
        help="Comma-separated warning-shift values in hours (default: -24,-16,-8,0,8,16,24).",
    )
    sweep.add_argument("--seed", type=int, default=0, help="Base seed (default 0).")
    sweep.add_argument(
        "--scenario",
        default="baseline",
        help="Scenario whose timeline gets shifted (default: baseline).",
    )
    sweep.add_argument(
        "--n-households",
        type=int,
        default=10_000,
        help="Households per simulation (default 10000).",
    )
    sweep.add_argument(
        "--n-hours", type=int, default=120, help="Simulation horizon in hours (default 120)."
    )
    sweep.add_argument("--quiet", action="store_true", help="Suppress per-cell progress.")


def _register_summary_action(
    children: argparse._SubParsersAction[argparse.ArgumentParser],
) -> None:
    """Add the ``bootstrap summary`` action and its options."""
    summary = children.add_parser(
        "summary", help="Print a (shift, metric) → median ± [P25,P75] / [P5,P95] table."
    )
    summary.add_argument(
        "--input", type=Path, required=True, help="Sweep parquet from `bootstrap shift-sweep`."
    )
def _run_fit(args: argparse.Namespace) -> int:
    """Execute ``bootstrap fit``: refit the model on household resamples.

    Loads the observations at ``args.input`` and converts them to fit data.
    If ``args.warm_start`` is given, EM is seeded from the ``params`` of that
    fit bundle. ``run_bootstrap_fits`` is then called with output directed to
    ``args.output_dir``. Unless ``--quiet`` is set, a start line and a
    timing/averages line are printed to stderr. The averages fall back to 0.0
    when no fits come back.

    Returns:
        Always ``0``.
    """
    fit_data = bundle_to_fit_data(load_bundle(args.input))
    config = EMConfig(max_iter=args.max_iter, tol=args.tol, verbose=False)
    warm: object | None = (
        read_fit_bundle(Path(args.warm_start)).params if args.warm_start is not None else None
    )
    chatty = not args.quiet
    if chatty:
        init_label = "cold" if warm is None else "warm"
        print(
            f"Bootstrap fit: N={fit_data.n}, replicates={args.n_replicates}, "
            f"jobs={args.jobs}, init={init_label}",
            file=sys.stderr,
        )
    t0 = time.perf_counter()
    replicates = run_bootstrap_fits(
        data=fit_data,
        n_replicates=args.n_replicates,
        em_config=config,
        base_seed=args.seed,
        n_jobs=args.jobs,
        output_dir=args.output_dir,
        warm_start_theta=warm,  # type: ignore[arg-type]
    )
    elapsed = time.perf_counter() - t0
    if chatty:
        if replicates:
            mean_iters = float(np.mean([r.iterations for r in replicates]))
            mean_ll = float(np.mean([r.final_log_likelihood for r in replicates]))
        else:
            mean_iters = 0.0
            mean_ll = 0.0
        print(
            f"Wrote {len(replicates)} replicates to {args.output_dir} "
            f"in {elapsed:.1f}s (avg iters={mean_iters:.1f}, avg LL={mean_ll:.2f})",
            file=sys.stderr,
        )
    return 0
def _run_shift_sweep(args: argparse.Namespace) -> int:
    """Execute ``bootstrap shift-sweep`` over previously saved replicate fits.

    Loads every replicate from ``args.bootstrap_dir`` and builds the base
    scenario named by ``args.scenario``. It then runs ``run_shift_sweep``
    across ``args.shifts`` and writes the result to ``args.output``. Progress
    lines go to stderr unless ``--quiet`` is set. The reported elapsed time
    covers both the sweep and the write.

    Returns:
        Always ``0``.
    """
    replicate_fits = load_bootstrap_fits(args.bootstrap_dir)
    base = build_scenario(args.scenario)
    chatty = not args.quiet
    if chatty:
        print(
            f"Shift sweep: {len(replicate_fits)} replicates × {len(args.shifts)} shifts, "
            f"scenario={args.scenario}",
            file=sys.stderr,
        )
    t0 = time.perf_counter()
    sweep = run_shift_sweep(
        bootstrap_fits=replicate_fits,
        shifts=tuple(args.shifts),
        scenario_base=base,
        n_households=args.n_households,
        n_hours=args.n_hours,
        base_seed=args.seed,
    )
    write_sweep_result(sweep, args.output)
    elapsed = time.perf_counter() - t0
    if chatty:
        print(
            f"Wrote {len(sweep.rows)} rows to {args.output} in {elapsed:.1f}s",
            file=sys.stderr,
        )
    return 0
def format_bootstrap_summary(result: ShiftSweepResult) -> str:
    """Render a ``(shift, metric)`` median ± [P25, P75] / [P5, P95] table.

    The table has one row per entry of ``bands.shifts`` and one column per
    entry of ``BAND_METRICS``, where ``bands`` is the result of
    ``compute_bands``. Each cell reads ``"P50 [P25,P75] (P5,P95)"`` with two
    decimals, and the shift column is formatted with an explicit sign
    (``+8``, ``-16``). Columns are left-aligned to their widest cell plus
    three spaces, and trailing whitespace is stripped from every line.

    Args:
        result: A loaded or freshly computed shift-sweep result.

    Returns:
        The table as newline-joined lines, header first, with no trailing
        newline.
    """
    percentiles = (5, 25, 50, 75, 95)
    bands = compute_bands(result, percentiles=percentiles)
    # The per-metric quantile arrays do not depend on the row being rendered,
    # so fetch each one once rather than five times per metric per shift.
    quantiles_by_metric = {
        metric: tuple(bands.quantile(metric, p) for p in percentiles) for metric in BAND_METRICS
    }
    header = ("shift", *BAND_METRICS)
    rows: list[tuple[str, ...]] = [header]
    for j, shift in enumerate(bands.shifts):
        cells: list[str] = [f"{shift:+d}"]
        for metric in BAND_METRICS:
            p5, p25, p50, p75, p95 = (q[j] for q in quantiles_by_metric[metric])
            cells.append(f"{p50:.2f} [{p25:.2f},{p75:.2f}] ({p5:.2f},{p95:.2f})")
        rows.append(tuple(cells))

    widths = [max(len(r[i]) for r in rows) for i in range(len(header))]
    pad = 3
    lines = [
        "".join(cell.ljust(width + pad) for cell, width in zip(r, widths, strict=True)).rstrip()
        for r in rows
    ]
    return "\n".join(lines)
def _run_summary(args: argparse.Namespace) -> int:
    """Execute ``bootstrap summary`` and write the band table to stdout.

    Reads the sweep parquet at ``args.input`` and renders it with
    ``format_bootstrap_summary``. The result is written to stdout followed by
    one newline.

    Returns:
        Always ``0``.
    """
    table = format_bootstrap_summary(load_sweep_result(args.input))
    sys.stdout.write(f"{table}\n")
    return 0
def run_bootstrap_command(args: argparse.Namespace) -> int:
    """Dispatch ``iohmm-evac bootstrap <action>`` to its handler.

    Args:
        args: Parsed namespace whose ``action`` is ``"fit"``,
            ``"shift-sweep"`` or ``"summary"``.

    Returns:
        The selected handler's exit code.

    Raises:
        ValueError: If ``args.action`` is not one of the known actions.
    """
    action = args.action
    if action == "fit":
        handler = _run_fit
    elif action == "shift-sweep":
        handler = _run_shift_sweep
    elif action == "summary":
        handler = _run_summary
    else:
        msg = f"Unknown bootstrap action: {args.action!r}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover
    return handler(args)