# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
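"""Household-level bootstrap resampling.

A resample is just an index vector of household positions; downstream code
consumes ``(observations[idx], inputs[idx], population[idx])``. The
underlying parquet rows are never duplicated: repeated households exist only
as repeated indices until :func:`index_fit_data` gathers them into new
in-memory arrays.
"""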
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass
import numpy as np
from iohmm_evac.inference.data import FitData
from iohmm_evac.types import IntArray
# Public API of this module: the resample container plus the two helpers that
# draw bootstrap index vectors and apply an index vector to a FitData bundle.
__all__ = [
    "ResampledFitData",
    "index_fit_data",
    "resample_indices",
]
@dataclass(frozen=True, slots=True)
class ResampledFitData:
"""A bootstrap-resampled :class:`FitData` plus the index vector that produced it."""
indices: IntArray
"""The household indices used for this resample, shape (N,)."""
data: FitData
"""The :class:`FitData` produced by indexing into the source bundle."""
def resample_indices(n: int, n_resamples: int, seed: int) -> Iterator[IntArray]:
    """Yield ``n_resamples`` index vectors of length ``n`` drawn with replacement.

    Each vector is an ``int64`` array of shape ``(n,)`` whose entries are
    drawn from ``[0, n)`` by :meth:`numpy.random.Generator.integers`.

    Reproducible: ``seed=k`` always produces the same sequence of resamples,
    independent of ``n_resamples`` -- the first vectors for a given seed are
    the same whether few or many resamples are requested.

    Args:
        n: Number of households drawn per resample; must be positive.
        n_resamples: Number of index vectors to yield; must be non-negative.
        seed: Seed passed to :func:`numpy.random.default_rng`.

    Returns:
        A lazy iterator over the index vectors.

    Raises:
        ValueError: If ``n`` is not positive or ``n_resamples`` is negative.
            Raised immediately at call time, not on the first ``next()``.
    """
    # Validate eagerly. If this function were itself a generator, these
    # checks would be deferred until the caller started iterating.
    if n <= 0:
        msg = f"n must be a positive integer, got {n}"
        raise ValueError(msg)
    if n_resamples < 0:
        msg = f"n_resamples must be non-negative, got {n_resamples}"
        raise ValueError(msg)
    return _draw_resamples(n, n_resamples, seed)


def _draw_resamples(n: int, n_resamples: int, seed: int) -> Iterator[IntArray]:
    """Generate the resamples for :func:`resample_indices` (arguments pre-validated)."""
    # One Generator per call, advanced sequentially, so resample k depends only
    # on ``seed`` and k -- never on how many resamples were requested.
    rng = np.random.default_rng(seed)
    for _ in range(n_resamples):
        # ``integers`` with an explicit dtype already returns a fresh int64 ndarray.
        yield rng.integers(0, n, size=n, dtype=np.int64)
def index_fit_data(data: FitData, indices: IntArray) -> FitData:
    """Return a :class:`FitData` whose households are ``data`` indexed by ``indices``.

    ``indices`` is a length-``M`` integer vector; rows may repeat. The
    resulting :class:`FitData` has ``M`` households and the same time axis.
    Each per-household array (and ``true_states`` when present) is gathered
    along its first axis and stored as a new C-contiguous array.

    Args:
        data: Source bundle to draw households from.
        indices: Household indices into ``data``; anything accepted by
            :func:`numpy.asarray` that yields a 1-D integer array.

    Returns:
        A new :class:`FitData` with ``M`` households.

    Raises:
        ValueError: If ``indices`` is not 1-D, does not have an integer
            dtype, or contains an index outside ``[0, data.n)``.
    """
    idx = np.asarray(indices)
    if idx.ndim != 1:
        msg = f"indices must be 1-D, got shape {idx.shape}"
        raise ValueError(msg)
    # A boolean vector would slip through the bounds check (False/True -> 0/1)
    # and then be applied as a mask, silently selecting a different number of
    # households; float vectors would fail later with an unrelated IndexError.
    if not np.issubdtype(idx.dtype, np.integer):
        msg = f"indices must have an integer dtype, got {idx.dtype}"
        raise ValueError(msg)
    if idx.size:
        lo, hi = int(idx.min()), int(idx.max())
        if lo < 0 or hi >= data.n:
            msg = f"indices out of bounds: min={lo}, max={hi}, n={data.n}"
            raise ValueError(msg)

    sub_states = (
        np.ascontiguousarray(data.true_states[idx]) if data.true_states is not None else None
    )
    return FitData(
        inputs=np.ascontiguousarray(data.inputs[idx]),
        departure=np.ascontiguousarray(data.departure[idx]),
        displacement=np.ascontiguousarray(data.displacement[idx]),
        comm=np.ascontiguousarray(data.comm[idx]),
        true_states=sub_states,
    )