# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""High-level :func:`fit` entry: multi-restart EM and best-result selection."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
import numpy as np
from numpy.random import Generator
from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.em import EMConfig, EMResult, run_em
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.inference.initialization import (
kmeans_init,
random_initialization,
)
# Public names of this module: the multi-restart result container and the fit entry point.
__all__ = ["FitResult", "fit"]
@dataclass(frozen=True, slots=True)
class FitResult:
"""Aggregate result of a multi-restart fit."""
best: EMResult
"""The restart with the highest final log-likelihood."""
all_runs: tuple[EMResult, ...]
"""Every restart's outcome, in input order."""
best_index: int
"""Index of ``best`` inside ``all_runs``."""
@property
def params(self) -> FitParameters:
"""Convenience accessor for ``best.params``."""
return self.best.params
def _strategy_factory(
name: str,
data: FitData,
rng: Generator,
truth_init: FitParameters | None,
) -> Callable[[Generator], FitParameters]:
if name == "random":
return lambda r: random_initialization(r)
if name == "kmeans":
return lambda r: kmeans_init(data, r)
if name == "truth":
if truth_init is None:
msg = "init='truth' requires truth_init to be supplied"
raise ValueError(msg)
return lambda _r: truth_init
msg = f"Unknown init strategy: {name!r}"
raise ValueError(msg)
def fit(
    data: FitData,
    *,
    n_restarts: int = 1,
    em_config: EMConfig | None = None,
    init: str = "random",
    rng: Generator | None = None,
    truth_init: FitParameters | None = None,
) -> FitResult:
    """Run EM ``n_restarts`` times and return the best fit by log-likelihood.

    Each restart draws a fresh seed from ``rng``, builds a dedicated
    generator from it, asks the ``init`` strategy for starting parameters,
    and runs :func:`run_em` from there. The restart with the highest final
    log-likelihood is reported as ``best``. A restart whose final
    log-likelihood is NaN is never preferred over one with a real value.
    If every restart is NaN, the first restart is returned.

    Args:
        data: Observations to fit.
        n_restarts: Number of EM runs to perform; must be at least 1.
        em_config: Forwarded unchanged to :func:`run_em`.
        init: Initialization strategy: ``"random"``, ``"kmeans"`` or ``"truth"``.
        rng: Source of per-restart seeds. When omitted,
            ``np.random.default_rng(0)`` is used, so repeated calls without
            ``rng`` reproduce the same seeds.
        truth_init: Starting parameters used when ``init="truth"``.

    Returns:
        A :class:`FitResult` with the best run, every run in execution
        order, and the index of the best run.

    Raises:
        ValueError: If ``n_restarts < 1``, ``init`` is unknown, or
            ``init="truth"`` is requested without ``truth_init``.
    """
    if n_restarts < 1:
        msg = "n_restarts must be at least 1"
        raise ValueError(msg)
    rng_local = rng if rng is not None else np.random.default_rng(0)
    # Resolve the strategy up front so a bad ``init`` fails before any EM work.
    factory = _strategy_factory(init, data, rng_local, truth_init)
    runs: list[EMResult] = []
    for _ in range(n_restarts):
        # Draws made by the strategy use sub_rng, never rng_local, so each
        # restart's seed is independent of how many numbers earlier
        # initializations consumed.
        seed = int(rng_local.integers(0, 2**31 - 1))
        sub_rng = np.random.default_rng(seed)
        params0 = factory(sub_rng)
        runs.append(run_em(params0, data, em_config))
    finals = np.array([r.final_log_likelihood for r in runs], dtype=np.float64)
    # np.argmax treats NaN as the maximum, so a diverged restart would win.
    # np.nanargmax skips NaNs but raises when all entries are NaN; fall back to
    # index 0 in that case, which matches the previous argmax result.
    if np.isnan(finals).all():
        best_idx = 0
    else:
        best_idx = int(np.nanargmax(finals))
    return FitResult(best=runs[best_idx], all_runs=tuple(runs), best_index=best_idx)