src/iohmm_evac/inference/m_step.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""M-step solvers: closed-form for π/emissions, L-BFGS-B for transitions."""

from __future__ import annotations

import numpy as np
from scipy.optimize import minimize

from iohmm_evac.inference.data import FitData
from iohmm_evac.inference.fit_params import (
    EmissionFitParams,
    FitParameters,
    InitialFitParams,
    TransitionFitParams,
)
from iohmm_evac.types import BoolArray, FloatArray

# Public API of this module. ``_flatten_origin_params`` is private (leading
# underscore) and deliberately left out of the export list.
__all__ = [
    "m_step",
    "transition_neg_q_and_grad",
    "update_emissions",
    "update_initial",
    "update_transitions",
]


def update_initial(log_gamma: FloatArray) -> InitialFitParams:
    """Re-estimate the initial-state distribution in closed form.

    The expected number of sequences that start in each state is the sum over
    sequences of the time-0 posteriors ``exp(log_gamma[:, 0, :])``. Each count
    is floored at ``1e-12`` so that no state's probability is exactly zero and
    its log stays finite. The counts are then normalized, and the resulting
    log-probabilities are returned as the initial logits.

    Args:
        log_gamma: Log posterior state marginals, indexed
            ``[sequence, time, state]``.

    Returns:
        ``InitialFitParams`` whose ``logits`` are float64 log-probabilities.
    """
    start_posteriors = np.exp(log_gamma[:, 0, :])  # (N, K)
    counts = np.maximum(start_posteriors.sum(axis=0), 1e-12)  # (K,)
    log_probs = np.log(counts / counts.sum())
    return InitialFitParams(logits=np.asarray(log_probs, dtype=np.float64))


def update_emissions(
    log_gamma: FloatArray, data: FitData, sigma_floor: float = 1e-2
) -> EmissionFitParams:
    """Re-estimate Bernoulli/Gaussian/Poisson emission parameters in closed form.

    Every parameter is a posterior-weighted statistic per hidden state, where
    the weights are ``exp(log_gamma)`` summed over sequences and time steps.

    - Departure probability: the weighted mean of ``data.departure``, clipped
      to ``[1e-6, 1 - 1e-6]``.
    - Displacement: the weighted mean and the weighted standard deviation
      about that new mean. The std is floored at ``sigma_floor``.
    - Communication rate: the weighted mean of ``data.comm``, floored at
      ``1e-6``.

    Args:
        log_gamma: Log posterior state marginals, shape ``(N, T+1, K)``.
        data: Observations. ``departure``, ``displacement`` and ``comm`` are
            indexed as 2-D arrays aligned with the first two axes of
            ``log_gamma``.
        sigma_floor: Lower bound for the displacement standard deviation.

    Returns:
        ``EmissionFitParams`` with float64 per-state arrays and ``sigma_floor``.
    """
    posterior = np.exp(log_gamma)  # (N, T+1, K)
    # Guard against states with (numerically) zero posterior mass.
    denom = np.maximum(posterior.sum(axis=(0, 1)), 1e-12)  # (K,)

    def weighted_mean(values: FloatArray) -> FloatArray:
        """Posterior-weighted per-state mean of an (N, T+1) observation array."""
        return (posterior * values[:, :, None]).sum(axis=(0, 1)) / denom

    dep_prob = np.clip(weighted_mean(data.departure), 1e-6, 1 - 1e-6)

    disp = data.displacement
    disp_mean = weighted_mean(disp)
    sq_resid = (disp[:, :, None] - disp_mean[None, None, :]) ** 2
    disp_var = (posterior * sq_resid).sum(axis=(0, 1)) / denom
    disp_std = np.sqrt(np.maximum(disp_var, sigma_floor * sigma_floor))

    comm_rate = np.maximum(weighted_mean(data.comm), 1e-6)

    return EmissionFitParams(
        p_departure=np.asarray(dep_prob, dtype=np.float64),
        mu_displacement=np.asarray(disp_mean, dtype=np.float64),
        sigma_displacement=np.asarray(disp_std, dtype=np.float64),
        lambda_comm=np.asarray(comm_rate, dtype=np.float64),
        sigma_floor=sigma_floor,
    )


def _flatten_origin_params(
    alpha_row: FloatArray, beta_row: FloatArray, learnable_j: BoolArray
) -> FloatArray:
    """Flatten the learnable subset of one origin row into a 1-D parameter vector.

    Layout: ``[alpha_j1, alpha_j2, ..., beta_j1_f1, beta_j1_f2, ..., beta_j2_f1, ...]``
    where the ``j*`` are the learnable destination indices in ascending order.
    """
    js = np.flatnonzero(learnable_j)
    f = beta_row.shape[1]
    n_dest = js.shape[0]
    out = np.empty(n_dest + n_dest * f, dtype=np.float64)
    out[:n_dest] = alpha_row[js]
    out[n_dest:] = beta_row[js].reshape(-1)
    return out


def transition_neg_q_and_grad(
    x: FloatArray,
    *,
    origin_k: int,
    learnable_j: BoolArray,
    xi_origin: FloatArray,
    gamma_origin: FloatArray,
    inputs_steps: FloatArray,
) -> tuple[float, FloatArray]:
    """Weighted multinomial-logit objective and gradient for one origin row, negated.

    Transitions out of ``origin_k`` follow a softmax over the joined axis
    ``[self_loop, *learnable]``. The self-loop logit is pinned at 0. A
    learnable destination ``d`` has logit ``alpha_d + inputs_t @ beta_d``.
    ``x`` holds those ``alpha``/``beta`` values in the
    ``_flatten_origin_params`` layout.

    With ``A`` the resulting probabilities, the objective is
    ``Q_k = sum_t [xi[t, k] * log A[t, k] + sum_d xi[t, d] * log A[t, d]]``.
    The function returns ``(-Q_k, -grad Q_k)``, ready for
    :func:`scipy.optimize.minimize` with ``jac=True``.

    Performance and assumptions:

    - Forbidden destinations are dropped entirely. This relies on their
      ``xi`` being zero, so they contribute nothing to ``Q``.
    - The gradient weights each step by ``gamma_origin``, which matches the
      ``xi`` row sum under that same assumption.
    - The β gradient is one ``inputs.T @ diff`` matmul rather than a
      ``(N*T, F, n_dest)`` outer-product sum.
    """
    dest = np.flatnonzero(learnable_j)
    n_feat = inputs_steps.shape[1]
    n_learn = dest.shape[0]
    intercepts = x[:n_learn]
    slopes = x[n_learn:].reshape(n_learn, n_feat)

    if n_learn == 0:
        # Absorbing origin: the self-loop has probability 1, so log A = 0
        # and both Q and its gradient vanish.
        return 0.0, np.zeros_like(x)

    eta = intercepts[None, :] + inputs_steps @ slopes.T  # (N_steps, n_learn)

    # Stable log-sum-exp over [self (logit 0), *learnable]. The shift is at
    # least 0 so the self-loop term exp(-shift) never overflows.
    shift = np.maximum(eta.max(axis=1), 0.0)  # (N_steps,)
    normaliser = np.exp(-shift) + np.exp(eta - shift[:, None]).sum(axis=1)
    log_norm = shift + np.log(normaliser)  # (N_steps,)

    log_p_self = -log_norm  # (N_steps,)
    log_p_learn = eta - log_norm[:, None]  # (N_steps, n_learn)
    p_learn = np.exp(log_p_learn)

    w_self = xi_origin[:, origin_k]
    w_learn = xi_origin[:, dest]  # (N_steps, n_learn)
    q_value = float(np.sum(w_self * log_p_self) + np.sum(w_learn * log_p_learn))

    # dQ/d(logit_d) at step t equals xi[t, d] - gamma[t] * A[t, d].
    resid = w_learn - gamma_origin[:, None] * p_learn  # (N_steps, n_learn)
    grad = np.empty(n_learn * (1 + n_feat), dtype=np.float64)
    grad[:n_learn] = resid.sum(axis=0)
    grad[n_learn:] = (inputs_steps.T @ resid).T.ravel()
    return -q_value, -grad


def update_transitions(
    log_xi: FloatArray,
    log_gamma: FloatArray,
    inputs: FloatArray,
    current: TransitionFitParams,
    *,
    maxiter: int = 50,
    tol: float = 1e-6,
) -> TransitionFitParams:
    """L-BFGS-B M-step for the transition row of every non-absorbing origin.

    A cell ``(k, j)`` is *allowed* iff ``current.alpha[k, j]`` is finite.

    Structural invariants, enforced on every row (absorbing rows included):

    - Disallowed cells get ``alpha = -inf`` and ``beta = 0``.
    - The self-loop cell ``(k, k)`` always gets ``alpha = 0`` and ``beta = 0``.

    Fitting:

    - Each origin with at least one allowed off-diagonal destination is refit
      by minimizing :func:`transition_neg_q_and_grad`, warm-started from
      ``current``.
    - Origins with no such destination (absorbing) are not optimized.

    Args:
        log_xi: Log pairwise posteriors, shape ``(N, T, K, K)``.
        log_gamma: Log state marginals, shape ``(N, T+1, K)``. Steps ``0..T-1``
            weight the transitions out of each state.
        inputs: Covariates indexed ``[sequence, time, feature]``. Steps
            ``1..`` are paired with the ``T`` transitions.
        current: Current transition parameters. They supply the allowed mask
            and the optimizer's starting point.
        maxiter: L-BFGS-B iteration cap per origin row.
        tol: L-BFGS-B projected-gradient tolerance (``gtol``).

    Returns:
        New ``TransitionFitParams``. ``current`` is not modified.
    """
    xi = np.exp(log_xi)  # (N, T, K, K)
    gamma = np.exp(log_gamma)  # (N, T+1, K)

    n, t_total, k_states, _ = xi.shape
    f = inputs.shape[2]
    allowed_mask = np.isfinite(current.alpha)
    learnable_mask = allowed_mask & ~np.eye(k_states, dtype=bool)

    inputs_steps = inputs[:, 1:, :].reshape(n * t_total, f)
    xi_flat = xi.reshape(n * t_total, k_states, k_states)
    gamma_flat = gamma[:, :-1, :].reshape(n * t_total, k_states)

    new_alpha = current.alpha.copy()
    new_beta = current.beta.copy()
    # Enforce the structural invariants up front for *all* rows, so absorbing
    # rows (which skip optimization) are normalized the same way as fitted
    # rows. Learnable cells are allowed and off-diagonal, so the loop below
    # overwrites exactly the cells these assignments leave untouched.
    new_alpha[~allowed_mask] = -np.inf
    new_beta[~allowed_mask] = 0.0
    diag = np.arange(k_states)
    # Index assignment rather than np.fill_diagonal: beta is (K, K, F) and
    # fill_diagonal requires all dimensions to be equal when ndim > 2.
    new_alpha[diag, diag] = 0.0
    new_beta[diag, diag] = 0.0

    for k in range(k_states):
        learnable_j = learnable_mask[k]
        js = np.flatnonzero(learnable_j)
        if js.size == 0:
            continue  # absorbing: row already fixed by the invariants above
        n_dest = js.shape[0]

        # Slice this origin's statistics once (contiguous copies) rather than
        # re-slicing strided views on every objective evaluation.
        xi_origin = np.ascontiguousarray(xi_flat[:, k, :])
        gamma_origin = np.ascontiguousarray(gamma_flat[:, k])
        x0 = _flatten_origin_params(current.alpha[k], current.beta[k], learnable_j)

        # Default arguments bind the per-iteration values explicitly, which
        # avoids any late-binding surprise from the loop variables.
        def objective(
            xx: FloatArray,
            k_local: int = k,
            lj: BoolArray = learnable_j,
            xi_k: FloatArray = xi_origin,
            gamma_k: FloatArray = gamma_origin,
        ) -> tuple[float, FloatArray]:
            return transition_neg_q_and_grad(
                xx,
                origin_k=k_local,
                learnable_j=lj,
                xi_origin=xi_k,
                gamma_origin=gamma_k,
                inputs_steps=inputs_steps,
            )

        result = minimize(
            objective,
            x0,
            jac=True,
            method="L-BFGS-B",
            options={"maxiter": maxiter, "gtol": tol},
        )
        x_opt = np.asarray(result.x, dtype=np.float64)
        new_alpha[k, js] = x_opt[:n_dest]
        new_beta[k, js] = x_opt[n_dest:].reshape(n_dest, f)

    return TransitionFitParams(alpha=new_alpha, beta=new_beta)


def m_step(
    log_gamma: FloatArray,
    log_xi: FloatArray,
    data: FitData,
    current: FitParameters,
    *,
    sigma_floor: float = 1e-2,
    transition_maxiter: int = 50,
    transition_tol: float = 1e-6,
) -> FitParameters:
    """Run one complete M-step and assemble the updated parameter set.

    The steps run in this order:

    1. The initial distribution and emission parameters are re-estimated in
       closed form from the state posteriors ``log_gamma``.
    2. Transition rows are refit numerically from the pairwise posteriors
       ``log_xi`` via :func:`update_transitions`, starting from
       ``current.transitions``.

    ``current.feature_names`` is carried over unchanged.

    Args:
        log_gamma: Log state marginals, shape ``(N, T+1, K)``.
        log_xi: Log pairwise posteriors, shape ``(N, T, K, K)``.
        data: Observations and covariates (``data.inputs``).
        current: Parameters from the previous iteration.
        sigma_floor: Lower bound on the displacement standard deviation.
        transition_maxiter: L-BFGS-B iteration cap per transition row.
        transition_tol: L-BFGS-B gradient tolerance per transition row.

    Returns:
        A new ``FitParameters``.
    """
    # Keyword arguments are evaluated left to right, so the sub-steps still
    # run in the order initial -> emissions -> transitions.
    return FitParameters(
        initial=update_initial(log_gamma),
        emissions=update_emissions(log_gamma, data, sigma_floor=sigma_floor),
        transitions=update_transitions(
            log_xi,
            log_gamma,
            data.inputs,
            current.transitions,
            maxiter=transition_maxiter,
            tol=transition_tol,
        ),
        feature_names=current.feature_names,
    )