# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Tests for the bootstrap household-resampling primitives."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
from iohmm_evac.bootstrap.resample import index_fit_data, resample_indices
from iohmm_evac.dgp.simulator import simulate
from iohmm_evac.inference.data import bundle_to_fit_data
from iohmm_evac.io import write_results
from iohmm_evac.params import SimulationConfig
from iohmm_evac.report.loader import load_bundle
@pytest.fixture(scope="module")
def small_bundle(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Simulate a small dataset once per module and return where it was written.

    The simulation is configured with 200 households, 24 hours and seed 0; the
    random generator is seeded from that same ``seed`` value. The result is
    written under a fresh temporary directory as ``obs.parquet`` and the path
    to that file is returned.
    """
    bundle_dir = tmp_path_factory.mktemp("resample_bundle")
    sim_config = SimulationConfig(n_households=200, n_hours=24, seed=0)
    sim_result = simulate(sim_config, np.random.default_rng(sim_config.seed))
    obs_path = bundle_dir / "obs.parquet"
    write_results(sim_result, obs_path)
    return obs_path
def test_resample_yields_correct_shape() -> None:
    """Each resample is a length-``n`` index vector with values in ``[0, n - 1]``."""
    n = 100
    n_resamples = 4
    draws = list(resample_indices(n=n, n_resamples=n_resamples, seed=0))
    assert len(draws) == n_resamples
    for draw in draws:
        assert draw.shape == (n,)
        assert int(draw.min()) >= 0
        assert int(draw.max()) <= n - 1
def test_resample_is_reproducible() -> None:
    """Two runs with identical arguments yield element-wise equal index arrays."""

    def draw_all() -> list:
        return list(resample_indices(n=200, n_resamples=3, seed=42))

    first_run = draw_all()
    second_run = draw_all()
    # strict=True also fails the test if the two runs differ in length.
    for lhs, rhs in zip(first_run, second_run, strict=True):
        np.testing.assert_array_equal(lhs, rhs)
def test_distinct_seeds_produce_distinct_sequences() -> None:
    """Different seeds must give a different index array for every resample."""
    from_seed_zero = list(resample_indices(n=200, n_resamples=2, seed=0))
    from_seed_one = list(resample_indices(n=200, n_resamples=2, seed=1))
    # Each pair should differ on at least one element. Check every
    # corresponding pair rather than only the first; strict=True additionally
    # fails the test if the two runs yield different numbers of resamples.
    for x, y in zip(from_seed_zero, from_seed_one, strict=True):
        assert not np.array_equal(x, y)
def test_average_resample_frequency_is_about_one() -> None:
    """Tally how often each index is drawn across resamples and check the mean."""
    n = 5_000
    n_resamples = 4
    tally = np.zeros(n, dtype=np.int64)
    for draw in resample_indices(n=n, n_resamples=n_resamples, seed=0):
        # ufunc.at is unbuffered, so an index that repeats in `draw` is
        # incremented once per occurrence.
        np.add.at(tally, draw, 1)
    # NOTE(review): if every resample has exactly n indices, tally.sum() is
    # n * n_resamples and the mean below equals n_resamples by construction,
    # so this check says little about the sampling distribution — confirm
    # whether a stronger assertion was intended.
    mean_tally = tally.mean()
    tolerance = 0.05 * n_resamples
    assert abs(mean_tally - n_resamples) < tolerance
def test_resample_with_replacement_has_repeats() -> None:
    """A single resample contains fewer distinct indices than its length."""
    draw = list(resample_indices(n=200, n_resamples=1, seed=0))[0]
    n_distinct = len(np.unique(draw))
    assert n_distinct < draw.shape[0]
def test_index_fit_data_shape(small_bundle: Path) -> None:
    """Indexing with repeated rows resizes the leading axis and keeps the rest."""
    full = bundle_to_fit_data(load_bundle(small_bundle))
    # Duplicates are deliberate; 199 is presumably the last household of the
    # 200-household fixture — verify against `small_bundle` if that changes.
    picks = np.array([0, 0, 1, 5, 199, 199], dtype=np.int64)
    subset = index_fit_data(full, picks)

    n_picked = len(picks)
    n_steps = full.t_total + 1
    assert subset.n == n_picked
    assert subset.t_total == full.t_total
    assert subset.inputs.shape == (n_picked, n_steps, full.f)
    assert subset.departure.shape == (n_picked, n_steps)
def test_index_fit_data_rejects_out_of_bounds(small_bundle: Path) -> None:
    """An index equal to ``data.n`` raises ``ValueError`` matching "out of bounds"."""
    full = bundle_to_fit_data(load_bundle(small_bundle))
    # Pair one valid index (0) with full.n, which the test expects to be rejected.
    past_end = np.array([0, full.n], dtype=np.int64)
    with pytest.raises(ValueError, match="out of bounds"):
        index_fit_data(full, past_end)
def test_resample_indices_validates_n() -> None:
    """Invalid ``n`` or ``n_resamples`` raise ``ValueError`` with a matching message."""
    invalid_cases = (
        ({"n": 0, "n_resamples": 1, "seed": 0}, "positive"),
        ({"n": 10, "n_resamples": -1, "seed": 0}, "non-negative"),
    )
    for kwargs, pattern in invalid_cases:
        with pytest.raises(ValueError, match=pattern):
            # list() drives the iterable in case validation only runs on iteration.
            list(resample_indices(**kwargs))