tests/test_bootstrap_resample.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Bootstrap household-resampling primitives."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from iohmm_evac.bootstrap.resample import index_fit_data, resample_indices
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.loader import load_bundle


@pytest.fixture(scope="module")
def small_bundle(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Simulate a small run once per module and return its parquet path.

    The simulation uses 200 households over 24 hours, seeded from the config
    so the bundle is reproducible across test sessions.
    """
    bundle_dir = tmp_path_factory.mktemp("resample_bundle")
    cfg = SimulationConfig(n_households=200, n_hours=24, seed=0)
    simulated = simulate(cfg, np.random.default_rng(cfg.seed))
    parquet_path = bundle_dir / "obs.parquet"
    write_results(simulated, parquet_path)
    return parquet_path


def test_resample_yields_correct_shape() -> None:
    """Each requested resample is a length-``n`` index array within ``[0, n)``."""
    n, n_resamples = 100, 4
    draws = list(resample_indices(n=n, n_resamples=n_resamples, seed=0))
    assert len(draws) == n_resamples
    for draw in draws:
        assert draw.shape == (n,)
        assert 0 <= int(draw.min())
        assert int(draw.max()) <= n - 1


def test_resample_is_reproducible() -> None:
    """Two streams built from the same seed emit identical index arrays."""
    first, second = (
        list(resample_indices(n=200, n_resamples=3, seed=42)) for _ in range(2)
    )
    for lhs, rhs in zip(first, second, strict=True):
        np.testing.assert_array_equal(lhs, rhs)


def test_distinct_seeds_produce_distinct_sequences() -> None:
    """Different seeds must not reproduce the same resample stream.

    Every positionally paired resample (same draw index, different seed) must
    differ on at least one element. Comparing only the first pair would let a
    generator that ignores the seed after its first draw slip through.
    """
    a = list(resample_indices(n=200, n_resamples=2, seed=0))
    b = list(resample_indices(n=200, n_resamples=2, seed=1))
    assert len(a) == len(b) == 2
    for x, y in zip(a, b, strict=True):
        assert not np.array_equal(x, y)


def test_average_resample_frequency_is_about_one() -> None:
    """Households are drawn about once per resample, uniformly at random.

    Summing ``n_resamples`` index arrays of length ``n`` into ``counts`` makes
    the mean count exactly ``n_resamples`` for *any* indices. That check alone
    therefore only confirms the resample length and cannot detect a biased
    sampler. The informative statistic is the per-resample share of distinct
    households. Under uniform sampling with replacement its expectation is
    ``1 - (1 - 1/n)**n``, which is about ``1 - 1/e``, i.e. ~0.632.
    """
    n = 5_000
    n_resamples = 4
    expected_unique_share = 1.0 - (1.0 - 1.0 / n) ** n
    counts = np.zeros(n, dtype=np.int64)
    for idx in resample_indices(n=n, n_resamples=n_resamples, seed=0):
        np.add.at(counts, idx, 1)
        # The share's standard deviation is ~0.007 at n=5000, so 0.03 is a
        # >4-sigma band: loose enough to be stable, tight enough to catch bias.
        unique_share = np.unique(idx).size / n
        assert abs(unique_share - expected_unique_share) < 0.03
    mean = counts.mean()
    assert abs(mean - n_resamples) < 0.05 * n_resamples


def test_resample_with_replacement_has_repeats() -> None:
    """Sampling 200 indices with replacement should yield at least one duplicate."""
    first, *_ = resample_indices(n=200, n_resamples=1, seed=0)
    assert np.unique(first).size < len(first)


def test_index_fit_data_shape(small_bundle: Path) -> None:
    """Gathering rows (duplicates allowed) keeps the time and feature dimensions."""
    full = bundle_to_fit_data(load_bundle(small_bundle))
    picks = np.asarray([0, 0, 1, 5, 199, 199], dtype=np.int64)
    n_picked = picks.size
    horizon = full.t_total + 1
    subset = index_fit_data(full, picks)
    assert subset.n == n_picked
    assert subset.t_total == full.t_total
    assert subset.inputs.shape == (n_picked, horizon, full.f)
    assert subset.departure.shape == (n_picked, horizon)


def test_index_fit_data_rejects_out_of_bounds(small_bundle: Path) -> None:
    """An index equal to ``n`` (one past the last household) raises ValueError."""
    fit_data = bundle_to_fit_data(load_bundle(small_bundle))
    one_past_end = np.array([0, fit_data.n], dtype=np.int64)
    with pytest.raises(ValueError, match="out of bounds"):
        index_fit_data(fit_data, one_past_end)


def test_resample_indices_validates_n() -> None:
    """A zero ``n`` and a negative ``n_resamples`` are both rejected.

    Each call is drained with ``list()`` so the error surfaces even if
    ``resample_indices`` validates lazily, e.g. as a generator.
    """
    invalid_calls = (
        ({"n": 0, "n_resamples": 1, "seed": 0}, "positive"),
        ({"n": 10, "n_resamples": -1, "seed": 0}, "non-negative"),
    )
    for kwargs, pattern in invalid_calls:
        with pytest.raises(ValueError, match=pattern):
            list(resample_indices(**kwargs))