# src/iohmm_evac/inference/initialization.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Initialization strategies for the IO-HMM parameter dataclasses."""

from __future__ import annotations

import numpy as np
from numpy.random import Generator

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
    ALLOWED_TRANSITIONS,
    EmissionFitParams,
    F,
    FitParameters,
    InitialFitParams,
    K,
    TransitionFitParams,
    dgp_truth_to_fit_init,
)
from iohmm_evac.params import EmissionParams as DGPEmissionParams
from iohmm_evac.params import PopulationParams as DGPPopulationParams
from iohmm_evac.params import TransitionParams as DGPTransitionParams
from iohmm_evac.types import State

# Public API of this module: the three initialization strategies defined below.
__all__ = [
    "from_dgp_truth",
    "kmeans_init",
    "random_initialization",
]


def from_dgp_truth(
    transitions: DGPTransitionParams,
    emissions: DGPEmissionParams,
    population: DGPPopulationParams | None = None,
) -> FitParameters:
    """Build the starting point from the data-generating process's own parameters.

    Thin wrapper: the three arguments are forwarded, unchanged and in the same
    positional order, to :func:`dgp_truth_to_fit_init`, whose result is
    returned as-is.

    Args:
        transitions: DGP transition parameters.
        emissions: DGP emission parameters.
        population: Optional DGP population parameters; ``None`` is passed
            through untouched.

    Returns:
        The ``FitParameters`` produced by :func:`dgp_truth_to_fit_init`.
    """
    truth_init = dgp_truth_to_fit_init(transitions, emissions, population)
    return truth_init


def random_initialization(
    rng: Generator,
    *,
    alpha_scale: float = 1.0,
    beta_scale: float = 0.3,
    sigma: float = 1.0,
) -> FitParameters:
    """Draw a jittered, generic starting point for one EM restart.

    Transition intercepts are 0 on the diagonal (self-loops), ``-3`` plus
    scaled standard-normal noise on every allowed off-diagonal entry, and
    ``-inf`` everywhere else. Transition slopes are scaled standard-normal
    draws on allowed off-diagonal entries and zero elsewhere. Initial-state
    logits are 0 for ``State.UA`` and -2 for the other states, plus small
    noise. Emission parameters start from fixed heuristic per-state values
    with added noise and are then clipped or floored.

    Args:
        rng: Source of all random draws.
        alpha_scale: Multiplier on the noise added to allowed off-diagonal
            transition intercepts.
        beta_scale: Multiplier on the noise used for allowed off-diagonal
            transition slopes.
        sigma: Multiplier on the noise added to the displacement means; the
            displacement scale of every state is ``max(sigma, 0.5)``.

    Returns:
        A ``FitParameters`` holding the initial, transition and emission
        blocks built here.
    """
    intercepts = np.full((K, K), -np.inf, dtype=np.float64)
    np.fill_diagonal(intercepts, 0.0)
    slopes = np.zeros((K, K, F), dtype=np.float64)

    # Row-major walk over (source, destination) pairs. For each learnable
    # off-diagonal pair the intercept is drawn before the slope vector, so the
    # random stream is consumed in a fixed, reproducible order.
    for src, dst in np.ndindex(K, K):
        if src != dst and ALLOWED_TRANSITIONS[src, dst]:
            intercepts[src, dst] = -3.0 + alpha_scale * rng.normal(0.0, 1.0)
            slopes[src, dst] = beta_scale * rng.normal(0.0, 1.0, size=F)

    # Initial distribution: favour State.UA, then jitter every logit slightly.
    start_logits = np.full(K, -2.0, dtype=np.float64)
    start_logits[int(State.UA)] = 0.0
    start_logits = start_logits + 0.1 * rng.normal(size=K)

    # Departure probability: low for every state, high for State.ER, and kept
    # strictly inside (0, 1) by clipping both before and after the override.
    depart_noise = rng.normal(size=K)
    depart_prob = np.clip(0.05 + 0.05 * depart_noise, 1e-3, 0.999)
    depart_prob[int(State.ER)] = 0.9 + 0.05 * rng.standard_normal()
    depart_prob = np.clip(depart_prob, 1e-3, 1 - 1e-3)

    # Heuristic per-state anchors. Both literals have five entries, so the
    # additions below only broadcast when K is 5.
    disp_anchor = np.array([0.5, 0.5, 0.5, 20.0, 60.0], dtype=np.float64)
    disp_mean = disp_anchor + sigma * rng.normal(size=K)
    disp_scale = np.full(K, max(sigma, 0.5), dtype=np.float64)

    comm_anchor = np.array([0.5, 1.5, 3.5, 2.5, 1.0], dtype=np.float64)
    comm_rate = comm_anchor + 0.3 * rng.normal(size=K)
    # Floor the rate so it stays strictly positive after the jitter.
    comm_rate = np.maximum(comm_rate, 1e-3)

    initial_block = InitialFitParams(logits=np.asarray(start_logits, dtype=np.float64))
    transition_block = TransitionFitParams(alpha=intercepts, beta=slopes)
    emission_block = EmissionFitParams(
        p_departure=np.asarray(depart_prob, dtype=np.float64),
        mu_displacement=np.asarray(disp_mean, dtype=np.float64),
        sigma_displacement=np.asarray(disp_scale, dtype=np.float64),
        lambda_comm=np.asarray(comm_rate, dtype=np.float64),
    )
    return FitParameters(
        initial=initial_block,
        transitions=transition_block,
        emissions=emission_block,
    )


def kmeans_init(data: FitData, rng: Generator, *, n_iter: int = 20) -> FitParameters:
    """Seed emission parameters via mini K-means on (departure, displacement, comm).

    Every observation is treated as a 3-vector ``(departure, displacement,
    comm)`` and clustered into ``K`` groups with a small Lloyd-style loop. The
    cluster centers, ordered by ascending displacement, supply the departure
    probability, displacement mean and communication rate of states
    ``0 .. K-1``.

    Transitions, the initial distribution, the displacement scale and the
    feature names are taken from :func:`random_initialization`; the other three
    emission arrays are replaced by the cluster centers.

    Args:
        data: Observations; its ``departure``, ``displacement`` and ``comm``
            arrays are flattened and must have the same number of elements.
        rng: Random generator. It is consumed first by
            :func:`random_initialization` and then to pick the starting centers.
        n_iter: Maximum number of assignment/update sweeps.

    Returns:
        The seeded ``FitParameters``. When there are fewer than ``K``
        observations, the unmodified random initialization is returned.
    """
    base = random_initialization(rng)
    # Build the feature matrix explicitly in float64. The centers inherit this
    # dtype, so an all-integer/bool input would otherwise make the per-cluster
    # means assigned below truncate silently.
    feats = np.stack(
        [
            np.asarray(data.departure, dtype=np.float64).reshape(-1),
            np.asarray(data.displacement, dtype=np.float64).reshape(-1),
            np.asarray(data.comm, dtype=np.float64).reshape(-1),
        ],
        axis=1,
    )
    n_total = feats.shape[0]
    if n_total < K:
        # Not enough points to pick K distinct starting centers.
        return base

    # Mini K-means with K initial centers drawn at random (distinct rows).
    idx = rng.choice(n_total, size=K, replace=False)
    centers = feats[idx].copy()
    for _ in range(n_iter):
        # Euclidean distance from every observation to every center.
        dists = np.linalg.norm(feats[:, None, :] - centers[None, :, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # Start from the old centers so an empty cluster keeps its position.
        new_centers = centers.copy()
        for k in range(K):
            mask = labels == k
            if mask.any():
                new_centers[k] = feats[mask].mean(axis=0)
        converged = np.allclose(new_centers, centers, atol=1e-6)
        centers = new_centers
        if converged:
            break

    # Map cluster centers (sorted by displacement) back to canonical state order.
    # NOTE(review): this assumes state indices increase with displacement.
    order = np.argsort(centers[:, 1])
    centers = centers[order]
    # Keep the departure probability strictly inside (0, 1) and the
    # communication rate strictly positive.
    p = np.clip(centers[:, 0], 1e-3, 1 - 1e-3)
    mu = centers[:, 1]
    lam = np.maximum(centers[:, 2], 1e-3)

    emit = EmissionFitParams(
        p_departure=np.asarray(p, dtype=np.float64),
        mu_displacement=np.asarray(mu, dtype=np.float64),
        sigma_displacement=base.emissions.sigma_displacement.copy(),
        lambda_comm=np.asarray(lam, dtype=np.float64),
    )
    return FitParameters(
        initial=base.initial,
        transitions=base.transitions,
        emissions=emit,
        feature_names=base.feature_names,
    )