src/iohmm_evac/inference/fit.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""High-level :func:`fit` entry: multi-restart EM and best-result selection."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

import numpy as np
from numpy.random import Generator

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.em import EMConfig, EMResult, run_em
from iohmm_evac.inference.fit_params import FitParameters
from iohmm_evac.inference.initialization import (
    kmeans_init,
    random_initialization,
)

# Public API of this module: the multi-restart result container and the fit
# entry point. ``_strategy_factory`` is an internal helper and is not exported.
__all__ = ["FitResult", "fit"]


@dataclass(frozen=True, slots=True)
class FitResult:
    """Aggregate result of a multi-restart fit."""

    best: EMResult
    """The restart with the highest final log-likelihood."""
    all_runs: tuple[EMResult, ...]
    """Every restart's outcome, in input order."""
    best_index: int
    """Index of ``best`` inside ``all_runs``."""

    @property
    def params(self) -> FitParameters:
        """Convenience accessor for ``best.params``."""
        return self.best.params


def _strategy_factory(
    name: str,
    data: FitData,
    rng: Generator,
    truth_init: FitParameters | None,
) -> Callable[[Generator], FitParameters]:
    if name == "random":
        return lambda r: random_initialization(r)
    if name == "kmeans":
        return lambda r: kmeans_init(data, r)
    if name == "truth":
        if truth_init is None:
            msg = "init='truth' requires truth_init to be supplied"
            raise ValueError(msg)
        return lambda _r: truth_init
    msg = f"Unknown init strategy: {name!r}"
    raise ValueError(msg)


def fit(
    data: FitData,
    *,
    n_restarts: int = 1,
    em_config: EMConfig | None = None,
    init: str = "random",
    rng: Generator | None = None,
    truth_init: FitParameters | None = None,
) -> FitResult:
    """Run EM ``n_restarts`` times and return the best fit by log-likelihood.

    Each restart draws a fresh integer seed from ``rng`` and builds its own
    ``Generator`` from that seed, so restarts do not share random state.

    Args:
        data: Observations passed to the initializer and to :func:`run_em`.
        n_restarts: Number of independent EM runs; must be at least 1.
        em_config: Forwarded unchanged to :func:`run_em` on every restart.
        init: Initialization strategy: ``"random"``, ``"kmeans"`` or ``"truth"``.
        rng: Source of per-restart seeds. When omitted,
            ``np.random.default_rng(0)`` is used, so the fit is reproducible.
        truth_init: Starting parameters used when ``init="truth"``.

    Returns:
        A :class:`FitResult` holding every run and the one with the highest
        final log-likelihood. Runs whose final log-likelihood is NaN are never
        selected over a run with a non-NaN value; if every run is NaN, index 0
        is returned.

    Raises:
        ValueError: If ``n_restarts < 1``, if ``init`` is not a known strategy,
            or if ``init="truth"`` is requested without ``truth_init``.
    """
    if n_restarts < 1:
        msg = "n_restarts must be at least 1"
        raise ValueError(msg)
    rng_local = rng if rng is not None else np.random.default_rng(0)
    # Resolve the strategy before any EM work so a bad ``init`` fails fast.
    factory = _strategy_factory(init, data, rng_local, truth_init)
    runs: list[EMResult] = []
    for _ in range(n_restarts):
        # Derive an independent stream per restart from the parent generator.
        seed = int(rng_local.integers(0, 2**31 - 1))
        sub_rng = np.random.default_rng(seed)
        params0 = factory(sub_rng)
        runs.append(run_em(params0, data, em_config))

    finals = np.array([r.final_log_likelihood for r in runs], dtype=np.float64)
    # np.argmax treats NaN as the maximum, which would crown a diverged run as
    # "best". Rank NaN as -inf instead; an all-NaN input still yields index 0.
    ranked = np.where(np.isnan(finals), -np.inf, finals)
    best_idx = int(np.argmax(ranked))
    return FitResult(best=runs[best_idx], all_runs=tuple(runs), best_index=best_idx)