# src/iohmm_evac/bootstrap/runner.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Parallel bootstrap-fit driver.

Pin BLAS thread counts at module import time so that joblib's outer
parallelism does not fight with numpy/scipy's inner threading. The variables
are set with ``setdefault`` so that explicit user overrides survive.
"""

from __future__ import annotations

import os

# Pin BLAS threads BEFORE numpy is imported in worker processes (loky workers
# import this module fresh). See ``docs/development.md`` for the rationale.
_thread_env_vars = (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "BLIS_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
)
for _var in _thread_env_vars:
    # ``setdefault``: an explicit user override always wins.
    os.environ.setdefault(_var, "1")

import tomllib  # noqa: E402
from dataclasses import dataclass  # noqa: E402
from pathlib import Path  # noqa: E402

import numpy as np  # noqa: E402
import pyarrow as pa  # noqa: E402
import pyarrow.parquet as pq  # noqa: E402
import tomli_w  # noqa: E402
from joblib import Parallel, delayed  # noqa: E402

from iohmm_evac.bootstrap.resample import index_fit_data, resample_indices  # noqa: E402
from iohmm_evac.inference.data import FitData  # noqa: E402
from iohmm_evac.inference.em import EMConfig, run_em  # noqa: E402
from iohmm_evac.inference.fit_params import FitParameters  # noqa: E402
from iohmm_evac.inference.initialization import random_initialization  # noqa: E402
from iohmm_evac.inference.io import (  # noqa: E402
    _params_from_dict,
    _params_to_dict,
)
from iohmm_evac.types import IntArray  # noqa: E402

# Public API of this module; keeps ``import *`` and doc tooling honest.
__all__ = [
    "BootstrapFit",
    "load_bootstrap_fits",
    "run_bootstrap_fits",
]


@dataclass(frozen=True, slots=True)
class BootstrapFit:
    """Immutable record of a single bootstrap replicate's EM fit.

    Bundles the fitted parameters with the EM diagnostics and the exact
    resample that produced them, so a replicate can be audited or re-run.
    """

    replicate_id: int  # zero-based replicate number
    params: FitParameters  # parameters returned by ``run_em``
    final_log_likelihood: float  # log-likelihood reported by ``run_em``
    iterations: int  # EM iteration count reported by ``run_em``
    converged: bool  # convergence flag reported by ``run_em``
    indices: IntArray  # household-index vector that produced this replicate


def _replicate_dir(output_dir: Path, replicate_id: int) -> Path:
    return output_dir / f"replicate_{replicate_id:03d}"


def _write_replicate(rep_dir: Path, fit: BootstrapFit) -> None:
    """Persist one replicate's fit to ``rep_dir``.

    Creates the directory if needed and writes three artifacts:
    ``theta.toml`` (fitted parameters), ``metadata.toml`` (EM diagnostics),
    and ``indices.parquet`` (the resampled household indices).
    """
    rep_dir.mkdir(parents=True, exist_ok=True)

    # Fitted parameters via the shared dict round-trip helpers.
    params_doc = _params_to_dict(fit.params)
    (rep_dir / "theta.toml").write_bytes(tomli_w.dumps(params_doc).encode("utf-8"))

    # EM diagnostics; coerce to plain builtins so tomli_w can serialize them.
    meta = {
        "replicate_id": int(fit.replicate_id),
        "final_log_likelihood": float(fit.final_log_likelihood),
        "iterations": int(fit.iterations),
        "converged": bool(fit.converged),
        "n_resampled": int(fit.indices.shape[0]),
    }
    (rep_dir / "metadata.toml").write_bytes(tomli_w.dumps(meta).encode("utf-8"))

    # Resample indices as a single int64 parquet column.
    column = pa.array(fit.indices.astype(np.int64), type=pa.int64())
    table = pa.table({"household_index": column})
    pq.write_table(table, rep_dir / "indices.parquet")  # type: ignore[no-untyped-call]


def _read_replicate(rep_dir: Path) -> BootstrapFit:
    """Rebuild a ``BootstrapFit`` from one ``replicate_NNN`` directory.

    Inverse of ``_write_replicate``: reads ``theta.toml``, ``metadata.toml``
    and ``indices.parquet``.
    """
    theta_doc = tomllib.loads((rep_dir / "theta.toml").read_text(encoding="utf-8"))
    params = _params_from_dict(theta_doc)
    meta = tomllib.loads((rep_dir / "metadata.toml").read_text(encoding="utf-8"))

    table = pq.read_table(rep_dir / "indices.parquet")  # type: ignore[no-untyped-call]
    indices = np.asarray(table.column("household_index").to_numpy(), dtype=np.int64)

    return BootstrapFit(
        replicate_id=int(meta["replicate_id"]),
        params=params,
        final_log_likelihood=float(meta["final_log_likelihood"]),
        iterations=int(meta["iterations"]),
        converged=bool(meta["converged"]),
        indices=indices,
    )


def _fit_one_replicate(
    replicate_id: int,
    indices: IntArray,
    data: FitData,
    em_config: EMConfig,
    base_seed: int,
    warm_start_theta: FitParameters | None,
    output_dir: Path,
) -> BootstrapFit:
    """Fit EM on one household resample and persist the result.

    Runs inside a loky worker. With no warm start, each replicate draws a
    fresh random initialization from a seed derived from ``base_seed`` and
    ``replicate_id``; with a warm start, every replicate begins from the
    same ``warm_start_theta``.
    """
    resampled = index_fit_data(data, indices)

    if warm_start_theta is None:
        # 100003 spreads the per-replicate seeds so distinct
        # (base_seed, replicate_id) pairs land on distinct streams.
        rng = np.random.default_rng(base_seed * 100003 + replicate_id)
        params0 = random_initialization(rng)
    else:
        params0 = warm_start_theta

    result = run_em(params0, resampled, em_config)
    fit = BootstrapFit(
        replicate_id=replicate_id,
        params=result.params,
        final_log_likelihood=float(result.final_log_likelihood),
        iterations=int(result.iterations),
        converged=bool(result.converged),
        indices=np.asarray(indices, dtype=np.int64),
    )
    _write_replicate(_replicate_dir(output_dir, replicate_id), fit)
    return fit


def run_bootstrap_fits(
    data: FitData,
    n_replicates: int,
    em_config: EMConfig,
    base_seed: int,
    n_jobs: int,
    output_dir: Path,
    warm_start_theta: FitParameters | None = None,
) -> list[BootstrapFit]:
    """Drive ``n_replicates`` parallel EM fits via joblib's loky backend.

    Each replicate gets its own household-resample index vector and writes
    a ``replicate_NNN/`` directory under ``output_dir`` containing
    ``theta.toml``, ``metadata.toml``, and ``indices.parquet``.

    Raises:
        ValueError: if ``n_replicates`` is less than 1.
    """
    if n_replicates < 1:
        raise ValueError(f"n_replicates must be >= 1, got {n_replicates}")
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # One index vector per replicate, all derived deterministically
    # from base_seed.
    all_indices = list(resample_indices(data.n, n_replicates, seed=base_seed))

    tasks = [
        delayed(_fit_one_replicate)(
            replicate_id=rep,
            indices=all_indices[rep],
            data=data,
            em_config=em_config,
            base_seed=base_seed,
            warm_start_theta=warm_start_theta,
            output_dir=output_dir,
        )
        for rep in range(n_replicates)
    ]
    fits: list[BootstrapFit] = list(Parallel(n_jobs=n_jobs, backend="loky")(tasks))
    # Normalize to replicate order regardless of worker completion order.
    fits.sort(key=lambda f: f.replicate_id)
    return fits


def load_bootstrap_fits(output_dir: Path) -> list[BootstrapFit]:
    """Load every ``replicate_NNN`` subdirectory under ``output_dir``.

    Raises:
        FileNotFoundError: if ``output_dir`` does not exist or contains no
            ``replicate_*`` subdirectories.
    """
    output_dir = Path(output_dir)
    if not output_dir.exists():
        raise FileNotFoundError(f"Bootstrap output directory not found: {output_dir}")
    replicate_dirs = [
        child
        for child in output_dir.iterdir()
        if child.is_dir() and child.name.startswith("replicate_")
    ]
    if not replicate_dirs:
        raise FileNotFoundError(f"No replicate_NNN subdirectories under {output_dir}")
    # Sorted by name for a deterministic replicate order.
    return [_read_replicate(d) for d in sorted(replicate_dirs)]