src/iohmm_evac/diagnostics/recovery.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""State and parameter recovery metrics for IO-HMM fits."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from iohmm_evac.inference.fit_params import (
    ALLOWED_TRANSITIONS,
    FEATURE_NAMES,
    EmissionFitParams,
    FitParameters,
    InitialFitParams,
    K,
    TransitionFitParams,
    learnable_indices,
)
from iohmm_evac.types import FloatArray, IntArray

# Public API of this module. The ``_rmse`` helper is intentionally left out.
__all__ = [
    "ParameterRecoveryReport",
    "align_fit_to_truth",
    "parameter_recovery",
    "state_recovery_accuracy",
    "state_recovery_confusion",
]


def state_recovery_confusion(
    true_states: IntArray,
    fit_states_aligned: IntArray,
    *,
    n_states: int | None = None,
) -> FloatArray:
    """Row-normalized ``K x K`` confusion matrix.

    ``confusion[true_label, fit_label]`` is the share of true-``true_label``
    observations that the (aligned) fit labeled as ``fit_label``. Rows for
    true labels that never occur are all zeros rather than NaN.

    Args:
        true_states: Ground-truth state labels.
        fit_states_aligned: Fitted labels already mapped onto the truth's
            ordering. Must have the same shape as ``true_states``.
        n_states: Number of states (matrix side length). Defaults to ``K``.

    Returns:
        A float64 array of shape ``(n_states, n_states)``.

    Raises:
        ValueError: If the two label arrays differ in shape, or any label lies
            outside ``[0, n_states)``.
    """
    n = K if n_states is None else int(n_states)
    true_arr = np.asarray(true_states)
    fit_arr = np.asarray(fit_states_aligned)
    # Same check as state_recovery_accuracy: ravel() would otherwise silently
    # pair unrelated cells whenever two different shapes share a size.
    if true_arr.shape != fit_arr.shape:
        msg = f"shape mismatch: true {true_arr.shape} vs fit {fit_arr.shape}"
        raise ValueError(msg)
    true_flat = true_arr.ravel()
    fit_flat = fit_arr.ravel()
    # Negative labels would wrap around to the last row/column via numpy
    # indexing instead of failing, so reject out-of-range labels explicitly.
    for which, labels in (("true", true_flat), ("fit", fit_flat)):
        if np.any((labels < 0) | (labels >= n)):
            msg = f"{which} state labels must lie in [0, {n}); got values outside that range"
            raise ValueError(msg)
    counts = np.zeros((n, n), dtype=np.float64)
    # np.add.at accumulates repeated (true, fit) pairs, unlike fancy-index +=.
    np.add.at(counts, (true_flat, fit_flat), 1.0)
    # Clamp empty rows to 1 so they divide to zeros instead of NaN.
    row_totals = np.maximum(counts.sum(axis=1, keepdims=True), 1.0)
    return np.asarray(counts / row_totals, dtype=np.float64)


def state_recovery_accuracy(true_states: IntArray, fit_states_aligned: IntArray) -> float:
    """Return the fraction of (i, t) cells whose aligned fit label equals the truth.

    Both label arrays must have exactly the same shape. Each element is one cell.

    Raises:
        ValueError: If the two arrays differ in shape.
    """
    true_shape = true_states.shape
    fit_shape = fit_states_aligned.shape
    if true_shape == fit_shape:
        hits = true_states == fit_states_aligned
        return float(np.mean(hits))
    raise ValueError(f"shape mismatch: true {true_shape} vs fit {fit_shape}")


@dataclass(frozen=True, slots=True)
class ParameterRecoveryReport:
    """Side-by-side recovery metrics, organized by parameter group.

    Built by :func:`parameter_recovery`. For each parameter group it holds
    the true values, the aligned fitted values, and a scalar RMSE between
    them. The dataclass is frozen, but its array fields are not defensively
    copied. As filled by ``parameter_recovery``, most arrays are the very
    arrays stored on the input ``FitParameters``.
    """

    # Transition parameters.
    # - As filled by parameter_recovery, the two alpha tables are NaN wherever
    #   ALLOWED_TRANSITIONS is false.
    # - beta carries an extra trailing axis (presumably one entry per
    #   feature; confirm).
    # - Both RMSEs count only cells marked learnable by learnable_indices().
    #   Alpha additionally skips cells whose true value is non-finite.
    transition_alpha_true: FloatArray
    transition_alpha_fit: FloatArray
    transition_beta_true: FloatArray
    transition_beta_fit: FloatArray
    transition_alpha_rmse: float
    transition_beta_rmse: float

    # Emission parameters, indexed by state on the leading axis.
    # The readings below are inferred from the field names (confirm):
    # departure probability, displacement mean/std, and a communication
    # rate. Each RMSE compares every entry of the corresponding arrays.
    emission_p_true: FloatArray
    emission_p_fit: FloatArray
    emission_mu_true: FloatArray
    emission_mu_fit: FloatArray
    emission_sigma_true: FloatArray
    emission_sigma_fit: FloatArray
    emission_lambda_true: FloatArray
    emission_lambda_fit: FloatArray
    emission_p_rmse: float
    emission_mu_rmse: float
    emission_sigma_rmse: float
    emission_lambda_rmse: float

    # Feature labels. parameter_recovery takes these from the module-level
    # FEATURE_NAMES, not from the truth or fit parameters.
    feature_names: tuple[str, ...]


def align_fit_to_truth(fit: FitParameters, perm: IntArray) -> FitParameters:
    """Relabel a fitted :class:`FitParameters` to match the truth's state ordering.

    ``perm[fit_label] = true_label``. Every state-indexed array is reordered
    so that index ``true_label`` of the result holds what the fit stored at
    ``fit_label``. This covers:

    - the initial-state logits;
    - both leading state axes of the transition ``alpha`` and ``beta`` arrays;
    - the per-state emission parameters.

    ``sigma_floor`` and ``feature_names`` are carried over unchanged.

    Args:
        fit: Fitted parameters in the fit's own (arbitrary) state labeling.
        perm: Length-``K`` permutation mapping fit labels to true labels.

    Returns:
        A new :class:`FitParameters` expressed in the truth's state ordering.

    Raises:
        ValueError: If ``perm`` is not a permutation of ``0 .. K-1``.
    """
    perm_arr = np.asarray(perm)
    # argsort only inverts a genuine permutation. Duplicates, gaps, or a wrong
    # length would silently scramble or drop states, so reject them up front.
    if perm_arr.shape != (K,) or not np.array_equal(np.sort(perm_arr), np.arange(K)):
        msg = f"perm must be a permutation of 0..{K - 1}, got {perm_arr.tolist()}"
        raise ValueError(msg)
    inv = np.argsort(perm_arr)  # inv[true_label] = fit_label

    transitions = fit.transitions
    emissions = fit.emissions
    return FitParameters(
        initial=InitialFitParams(logits=fit.initial.logits[inv]),
        transitions=TransitionFitParams(
            # Both leading axes are state axes, so permute each of them.
            alpha=transitions.alpha[inv][:, inv],
            beta=transitions.beta[inv][:, inv],
        ),
        emissions=EmissionFitParams(
            p_departure=emissions.p_departure[inv],
            mu_displacement=emissions.mu_displacement[inv],
            sigma_displacement=emissions.sigma_displacement[inv],
            lambda_comm=emissions.lambda_comm[inv],
            sigma_floor=emissions.sigma_floor,
        ),
        feature_names=fit.feature_names,
    )


def _rmse(a: FloatArray, b: FloatArray, mask: np.ndarray | None = None) -> float:
    a_use = a[mask] if mask is not None else a
    b_use = b[mask] if mask is not None else b
    if a_use.size == 0:
        return 0.0
    diff = a_use - b_use
    return float(np.sqrt(np.mean(diff * diff)))


def parameter_recovery(truth: FitParameters, fit_aligned: FitParameters) -> ParameterRecoveryReport:
    """Build a :class:`ParameterRecoveryReport` comparing ``truth`` with ``fit_aligned``.

    ``fit_aligned`` should already be in the truth's state ordering (see
    :func:`align_fit_to_truth`).

    Transition RMSEs only count cells the model actually learns:

    - alpha additionally skips cells whose true value is non-finite, since
      forbidden transitions carry ``-inf``;
    - beta uses the learnable mask broadcast across its trailing axis.

    Emission RMSEs compare every entry. In the reported alpha tables,
    disallowed transitions are shown as NaN.
    """
    true_tr = truth.transitions
    fit_tr = fit_aligned.transitions
    true_em = truth.emissions
    fit_em = fit_aligned.emissions

    learn_mask, _unused = learnable_indices()
    alpha_cells = learn_mask & np.isfinite(true_tr.alpha)
    beta_cells = np.broadcast_to(learn_mask[:, :, None], true_tr.beta.shape)

    # NaN-out forbidden transitions in BOTH displayed alpha tables, so a shared
    # -inf is never presented as a perfect match.
    shown_alpha_true = np.where(ALLOWED_TRANSITIONS, true_tr.alpha, np.nan)
    shown_alpha_fit = np.where(ALLOWED_TRANSITIONS, fit_tr.alpha, np.nan)

    return ParameterRecoveryReport(
        transition_alpha_true=shown_alpha_true,
        transition_alpha_fit=shown_alpha_fit,
        transition_beta_true=true_tr.beta,
        transition_beta_fit=fit_tr.beta,
        transition_alpha_rmse=_rmse(true_tr.alpha, fit_tr.alpha, alpha_cells),
        transition_beta_rmse=_rmse(true_tr.beta, fit_tr.beta, beta_cells),
        emission_p_true=true_em.p_departure,
        emission_p_fit=fit_em.p_departure,
        emission_mu_true=true_em.mu_displacement,
        emission_mu_fit=fit_em.mu_displacement,
        emission_sigma_true=true_em.sigma_displacement,
        emission_sigma_fit=fit_em.sigma_displacement,
        emission_lambda_true=true_em.lambda_comm,
        emission_lambda_fit=fit_em.lambda_comm,
        emission_p_rmse=_rmse(true_em.p_departure, fit_em.p_departure),
        emission_mu_rmse=_rmse(true_em.mu_displacement, fit_em.mu_displacement),
        emission_sigma_rmse=_rmse(true_em.sigma_displacement, fit_em.sigma_displacement),
        emission_lambda_rmse=_rmse(true_em.lambda_comm, fit_em.lambda_comm),
        feature_names=tuple(FEATURE_NAMES),
    )