tests/test_bootstrap_cli.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""End-to-end CLI smoke test on a tiny configuration."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from iohmm_evac.bootstrap.shift_sweep import load_sweep_result
from iohmm_evac.bootstrap_cli import format_bootstrap_summary
from iohmm_evac.cli import main
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig


@pytest.fixture(scope="module")
def baseline_obs(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Simulate a small seeded baseline once per module and write it to disk.

    The population is 200 households over 24 hours, seeded with 0. The
    fixture is module-scoped, so every test in this file shares a single
    simulation run.

    Returns:
        Path to the ``baseline.parquet`` file produced by ``write_results``.
    """
    workdir = tmp_path_factory.mktemp("bootstrap_cli")
    cfg = SimulationConfig(n_households=200, n_hours=24, seed=0)
    # Seed the generator from the config so the baseline is reproducible.
    simulated = simulate(cfg, np.random.default_rng(cfg.seed))
    parquet_path = workdir / "baseline.parquet"
    write_results(simulated, parquet_path)
    return parquet_path


def test_bootstrap_cli_end_to_end(tmp_path: Path, baseline_obs: Path) -> None:
    """Chain ``bootstrap fit`` -> ``bootstrap shift-sweep`` -> ``report bootstrap-bands``.

    The configuration is kept tiny (2 replicates, 3 shifts, ``--max-iter 3``)
    so the whole CLI pipeline runs quickly. Every artifact each stage writes
    is still checked.
    """
    fits_dir = tmp_path / "fits"
    sweep_path = tmp_path / "sweep.parquet"
    bands_png = tmp_path / "bands.png"

    # Stage 1: fit two bootstrap replicates from the simulated baseline.
    fit_argv = [
        "bootstrap",
        "fit",
        "--input", str(baseline_obs),
        "--output-dir", str(fits_dir),
        "--n-replicates", "2",
        "--jobs", "1",
        "--max-iter", "3",
        "--tol", "1e-3",
        "--quiet",
    ]
    # Call main() outside the assert so it still runs under `python -O`.
    exit_code = main(fit_argv)
    assert exit_code == 0

    replicate_dirs = sorted(fits_dir.glob("replicate_*"))
    assert len(replicate_dirs) == 2
    expected_artifacts = ("theta.toml", "metadata.toml", "indices.parquet")
    for replicate_dir in replicate_dirs:
        for artifact in expected_artifacts:
            assert (replicate_dir / artifact).exists()

    # Stage 2: sweep three shifts over the fitted replicates. The "=" form of
    # --shifts is presumably used so the leading "-" is not parsed as a flag.
    sweep_argv = [
        "bootstrap",
        "shift-sweep",
        "--bootstrap-dir", str(fits_dir),
        "--output", str(sweep_path),
        "--shifts=-8,0,8",
        "--n-households", "200",
        "--n-hours", "24",
        "--quiet",
    ]
    exit_code = main(sweep_argv)
    assert exit_code == 0
    assert sweep_path.exists()
    sweep = load_sweep_result(sweep_path)
    # 2 replicates x 3 shifts.
    assert len(sweep.rows) == 2 * 3

    # Stage 3: render the uncertainty-band figure from the sweep output.
    report_argv = [
        "report",
        "bootstrap-bands",
        "--input", str(sweep_path),
        "--output", str(bands_png),
    ]
    exit_code = main(report_argv)
    assert exit_code == 0
    assert bands_png.exists()
    assert bands_png.stat().st_size > 0


def test_bootstrap_summary_table_smoke(tmp_path: Path, baseline_obs: Path) -> None:
    """Smoke-test ``format_bootstrap_summary`` on a freshly generated shift sweep.

    Runs ``bootstrap fit`` (2 replicates, ``--max-iter 2``) and then
    ``bootstrap shift-sweep`` (shifts -8 and 0). It then checks that the
    rendered summary mentions the shift column and the
    ``failed_evacuation_count`` metric.

    Each CLI exit code is asserted. A failing stage is therefore reported
    where it happens rather than later as a missing-file error from
    ``load_sweep_result``.
    """
    fits_dir = tmp_path / "fits"
    sweep_path = tmp_path / "sweep.parquet"

    rc = main(
        [
            "bootstrap",
            "fit",
            "--input",
            str(baseline_obs),
            "--output-dir",
            str(fits_dir),
            "--n-replicates",
            "2",
            "--jobs",
            "1",
            "--max-iter",
            "2",
            "--tol",
            "1e-3",
            "--quiet",
        ]
    )
    assert rc == 0

    rc = main(
        [
            "bootstrap",
            "shift-sweep",
            "--bootstrap-dir",
            str(fits_dir),
            "--output",
            str(sweep_path),
            "--shifts=-8,0",
            "--n-households",
            "200",
            "--n-hours",
            "24",
            "--quiet",
        ]
    )
    assert rc == 0

    sweep_result = load_sweep_result(sweep_path)
    # 2 replicates x 2 shifts. This mirrors the row-count check in the
    # end-to-end CLI test (replicates * shifts rows).
    assert len(sweep_result.rows) == 2 * 2

    text = format_bootstrap_summary(sweep_result)
    assert "shift" in text
    assert "failed_evacuation_count" in text
    # The zero shift should appear in the table, possibly rendered as "+0".
    # Any string containing "+0" also contains "0", so "0" is the real check.
    # It is a weak smoke check, not a format assertion.
    assert "0" in text