tests/test_bootstrap_runner.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Bootstrap runner tests: parallel EM fits + warm-start sanity check."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from iohmm_evac.bootstrap.runner import (
    BootstrapFit,
    load_bootstrap_fits,
    run_bootstrap_fits,
)
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.inference.em import EMConfig
from iohmm_evac.inference.fit_params import dgp_truth_to_fit_init
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.loader import load_bundle


@pytest.fixture(scope="module")
def small_data(tmp_path_factory: pytest.TempPathFactory):  # type: ignore[no-untyped-def]
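    """Simulate a small dataset once per module and package it as fit inputs."""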
    out = tmp_path_factory.mktemp("runner")
    config = SimulationConfig(n_households=200, n_hours=24, seed=0)
    rng = np.random.default_rng(config.seed)
    result = simulate(config, rng)
    obs = out / "obs.parquet"
    write_results(result, obs)
    bundle = load_bundle(obs)
    data = bundle_to_fit_data(bundle)
    truth = dgp_truth_to_fit_init(config.transitions, config.emissions, config.population)
    return data, truth


def test_run_bootstrap_fits_writes_replicate_dirs(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
    data, _ = small_data
    out_dir = tmp_path / "fits"
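    # Tiny EM budget: this test only cares about the artifacts written to disk.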
    em_config = EMConfig(max_iter=3, tol=1e-3, verbose=False)
    fits = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=em_config,
        base_seed=0,
        n_jobs=2,
        output_dir=out_dir,
    )
    assert len(fits) == 2
    for fit in fits:
        rep_dir = out_dir / f"replicate_{fit.replicate_id:03d}"
        for name in ("theta.toml", "metadata.toml", "indices.parquet"):
            assert (rep_dir / name).exists(), f"missing {name} in {rep_dir}"


def test_log_likelihoods_are_finite(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
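    """Even a truncated EM run must report a finite log-likelihood and >= 1 iteration."""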
    data, _ = small_data
    out_dir = tmp_path / "fits"
    fits = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=EMConfig(max_iter=3, tol=1e-3, verbose=False),
        base_seed=0,
        n_jobs=1,
        output_dir=out_dir,
    )
    for fit in fits:
        assert np.isfinite(fit.final_log_likelihood)
        assert fit.iterations >= 1


def test_load_bootstrap_fits_round_trips(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
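    """Fits written to disk should reload with matching ids, indices, and parameters."""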
    data, _ = small_data
    out_dir = tmp_path / "fits"
    fits = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=EMConfig(max_iter=2, tol=1e-3, verbose=False),
        base_seed=0,
        n_jobs=1,
        output_dir=out_dir,
    )
    loaded = load_bootstrap_fits(out_dir)
    assert len(loaded) == len(fits)
    for original, reloaded in zip(fits, loaded, strict=True):
        assert original.replicate_id == reloaded.replicate_id
        assert original.iterations == reloaded.iterations
        np.testing.assert_array_equal(original.indices, reloaded.indices)
        np.testing.assert_allclose(
            original.params.transitions.beta, reloaded.params.transitions.beta
        )


def test_warm_start_reduces_iteration_count(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
    """Sanity check, not a strict assertion: warm beats cold on average."""
    data, truth = small_data
    em_config = EMConfig(max_iter=20, tol=1e-4, verbose=False)
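    # Same base_seed for both runs: assuming the seed drives the resampling,
    # cold and warm fits see identical bootstrap samples and differ only in
    # their initialization.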
    cold = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=em_config,
        base_seed=0,
        n_jobs=1,
        output_dir=tmp_path / "cold",
    )
    warm = run_bootstrap_fits(
        data=data,
        n_replicates=2,
        em_config=em_config,
        base_seed=0,
        n_jobs=1,
        output_dir=tmp_path / "warm",
        warm_start_theta=truth,
    )
    cold_avg = float(np.mean([f.iterations for f in cold]))
    warm_avg = float(np.mean([f.iterations for f in warm]))
    assert warm_avg <= cold_avg, (
        f"warm avg iters ({warm_avg}) should be <= cold avg iters ({cold_avg})"
    )


def test_n_replicates_must_be_positive(tmp_path: Path, small_data) -> None:  # type: ignore[no-untyped-def]
    data, _ = small_data
    with pytest.raises(ValueError, match=">= 1"):
        run_bootstrap_fits(
            data=data,
            n_replicates=0,
            em_config=EMConfig(max_iter=1, tol=1e-3, verbose=False),
            base_seed=0,
            n_jobs=1,
            output_dir=tmp_path / "empty",
        )


def test_load_bootstrap_fits_missing_dir(tmp_path: Path) -> None:
    with pytest.raises(FileNotFoundError):
        load_bootstrap_fits(tmp_path / "no-such")


def test_load_bootstrap_fits_empty_dir(tmp_path: Path) -> None:
    out = tmp_path / "empty_fits"
    out.mkdir()
    with pytest.raises(FileNotFoundError, match="No replicate"):
        load_bootstrap_fits(out)


def test_bootstrap_fit_dataclass_is_frozen() -> None:
    indices = np.zeros(3, dtype=np.int64)
    fit = BootstrapFit(
        replicate_id=0,
        params=None,  # type: ignore[arg-type]
        final_log_likelihood=-1.0,
        iterations=1,
        converged=False,
        indices=indices,
    )
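    # dataclasses.FrozenInstanceError subclasses AttributeError, so the
    # frozen-dataclass rejection below is caught by pytest.raises.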
    with pytest.raises(AttributeError):
        fit.replicate_id = 99  # type: ignore[misc]