# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Parallel bootstrap-fit driver.
Pin BLAS thread counts at module import time so that joblib's outer
parallelism does not fight with numpy/scipy's inner threading. The variables
are set with ``setdefault`` so that explicit user overrides survive.
"""
from __future__ import annotations
import os
# Pin BLAS threads BEFORE numpy is imported in worker processes (loky workers
# import this module fresh). See ``docs/development.md`` for the rationale.
_THREAD_ENV_VARS = (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "BLIS_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
)
for _name in _THREAD_ENV_VARS:
    # ``setdefault`` keeps any value the user exported explicitly.
    os.environ.setdefault(_name, "1")
import tomllib # noqa: E402
from dataclasses import dataclass # noqa: E402
from pathlib import Path # noqa: E402
import numpy as np # noqa: E402
import pyarrow as pa # noqa: E402
import pyarrow.parquet as pq # noqa: E402
import tomli_w # noqa: E402
from joblib import Parallel, delayed # noqa: E402
from iohmm_evac.bootstrap.resample import index_fit_data, resample_indices # noqa: E402
from iohmm_evac.inference.data import FitData # noqa: E402
from iohmm_evac.inference.em import EMConfig, run_em # noqa: E402
from iohmm_evac.inference.fit_params import FitParameters # noqa: E402
from iohmm_evac.inference.initialization import random_initialization # noqa: E402
from iohmm_evac.inference.io import ( # noqa: E402
_params_from_dict,
_params_to_dict,
)
from iohmm_evac.types import IntArray # noqa: E402
# Public API: the replicate record plus the run/load entry points.
__all__ = [
    "BootstrapFit",
    "load_bootstrap_fits",
    "run_bootstrap_fits",
]
@dataclass(frozen=True, slots=True)
class BootstrapFit:
"""Output of a single bootstrap replicate's EM fit."""
replicate_id: int
params: FitParameters
final_log_likelihood: float
iterations: int
converged: bool
indices: IntArray
"""The household-index vector that produced this replicate."""
def _replicate_dir(output_dir: Path, replicate_id: int) -> Path:
return output_dir / f"replicate_{replicate_id:03d}"
def _write_replicate(rep_dir: Path, fit: BootstrapFit) -> None:
    """Serialize one replicate to ``rep_dir``.

    Writes three artifacts: ``theta.toml`` (fitted parameters),
    ``metadata.toml`` (scalar fit summary), and ``indices.parquet``
    (the household resample indices).
    """
    rep_dir.mkdir(parents=True, exist_ok=True)
    # Fitted parameters, round-trippable via ``_params_from_dict``.
    theta_bytes = tomli_w.dumps(_params_to_dict(fit.params)).encode("utf-8")
    (rep_dir / "theta.toml").write_bytes(theta_bytes)
    # Scalar summary of the fit; plain builtins so TOML serialization is exact.
    meta = {
        "replicate_id": int(fit.replicate_id),
        "final_log_likelihood": float(fit.final_log_likelihood),
        "iterations": int(fit.iterations),
        "converged": bool(fit.converged),
        "n_resampled": int(fit.indices.shape[0]),
    }
    (rep_dir / "metadata.toml").write_bytes(tomli_w.dumps(meta).encode("utf-8"))
    # Resample indices as a single-column int64 parquet table.
    column = pa.array(fit.indices.astype(np.int64), type=pa.int64())
    table = pa.table({"household_index": column})
    pq.write_table(table, rep_dir / "indices.parquet")  # type: ignore[no-untyped-call]
def _read_replicate(rep_dir: Path) -> BootstrapFit:
    """Rebuild a :class:`BootstrapFit` from a directory written by ``_write_replicate``."""
    with (rep_dir / "theta.toml").open("rb") as fh:
        params = _params_from_dict(tomllib.load(fh))
    with (rep_dir / "metadata.toml").open("rb") as fh:
        meta = tomllib.load(fh)
    table = pq.read_table(rep_dir / "indices.parquet")  # type: ignore[no-untyped-call]
    # Force a plain int64 ndarray regardless of what pyarrow hands back.
    household_indices = np.asarray(
        table.column("household_index").to_numpy(), dtype=np.int64
    )
    return BootstrapFit(
        replicate_id=int(meta["replicate_id"]),
        params=params,
        final_log_likelihood=float(meta["final_log_likelihood"]),
        iterations=int(meta["iterations"]),
        converged=bool(meta["converged"]),
        indices=household_indices,
    )
def _fit_one_replicate(
    replicate_id: int,
    indices: IntArray,
    data: FitData,
    em_config: EMConfig,
    base_seed: int,
    warm_start_theta: FitParameters | None,
    output_dir: Path,
) -> BootstrapFit:
    """Fit EM on one household resample and persist the result to disk.

    Runs inside a loky worker. Initialization is either the shared
    warm-start parameters or a random draw seeded deterministically
    from ``base_seed`` and ``replicate_id``.
    """
    resampled = index_fit_data(data, indices)
    if warm_start_theta is None:
        # Large multiplier keeps per-replicate seed streams well separated.
        rng = np.random.default_rng(base_seed * 100003 + replicate_id)
        initial = random_initialization(rng)
    else:
        initial = warm_start_theta
    result = run_em(initial, resampled, em_config)
    fit = BootstrapFit(
        replicate_id=replicate_id,
        params=result.params,
        final_log_likelihood=float(result.final_log_likelihood),
        iterations=int(result.iterations),
        converged=bool(result.converged),
        indices=np.asarray(indices, dtype=np.int64),
    )
    _write_replicate(_replicate_dir(output_dir, replicate_id), fit)
    return fit
def run_bootstrap_fits(
    data: FitData,
    n_replicates: int,
    em_config: EMConfig,
    base_seed: int,
    n_jobs: int,
    output_dir: Path,
    warm_start_theta: FitParameters | None = None,
) -> list[BootstrapFit]:
    """Drive ``n_replicates`` parallel EM fits via joblib's loky backend.

    Each replicate gets its own household-resample index vector and writes
    a ``replicate_NNN/`` directory under ``output_dir`` containing
    ``theta.toml``, ``metadata.toml``, and ``indices.parquet``.
    """
    if n_replicates < 1:
        raise ValueError(f"n_replicates must be >= 1, got {n_replicates}")
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Draw every resample index vector up front so results are reproducible
    # regardless of how workers are scheduled.
    index_seqs = list(resample_indices(data.n, n_replicates, seed=base_seed))
    runner = Parallel(n_jobs=n_jobs, backend="loky")
    fits: list[BootstrapFit] = list(
        runner(
            delayed(_fit_one_replicate)(
                replicate_id=rep_id,
                indices=index_seqs[rep_id],
                data=data,
                em_config=em_config,
                base_seed=base_seed,
                warm_start_theta=warm_start_theta,
                output_dir=output_dir,
            )
            for rep_id in range(n_replicates)
        )
    )
    # Workers may complete out of order; present results by replicate id.
    return sorted(fits, key=lambda fit: fit.replicate_id)
def load_bootstrap_fits(output_dir: Path) -> list[BootstrapFit]:
    """Load every ``replicate_NNN`` subdirectory under ``output_dir``.

    Returns the fits sorted numerically by ``replicate_id`` (matching the
    order ``run_bootstrap_fits`` returns). Sorting by the loaded id rather
    than by directory name fixes the ordering for replicate ids >= 1000,
    where the 3-digit zero padding breaks lexicographic order
    (``replicate_1000`` sorts before ``replicate_101``).

    Raises:
        FileNotFoundError: if ``output_dir`` does not exist, or it contains
            no ``replicate_NNN`` subdirectories.
    """
    output_dir = Path(output_dir)
    if not output_dir.exists():
        msg = f"Bootstrap output directory not found: {output_dir}"
        raise FileNotFoundError(msg)
    rep_dirs = [
        p
        for p in output_dir.iterdir()
        if p.is_dir() and p.name.startswith("replicate_")
    ]
    if not rep_dirs:
        msg = f"No replicate_NNN subdirectories under {output_dir}"
        raise FileNotFoundError(msg)
    fits = [_read_replicate(d) for d in rep_dirs]
    # Numeric sort on the recorded id, not the directory name.
    fits.sort(key=lambda fit: fit.replicate_id)
    return fits