# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""State and parameter recovery metrics for IO-HMM fits."""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from iohmm_evac.inference.fit_params import (
ALLOWED_TRANSITIONS,
FEATURE_NAMES,
FitParameters,
K,
learnable_indices,
)
from iohmm_evac.types import FloatArray, IntArray
# Public API of this module: the recovery report type plus the label-alignment,
# state-recovery, and parameter-recovery entry points. ``_rmse`` stays private.
__all__ = [
    "ParameterRecoveryReport",
    "align_fit_to_truth",
    "parameter_recovery",
    "state_recovery_accuracy",
    "state_recovery_confusion",
]
def state_recovery_confusion(
    true_states: IntArray,
    fit_states_aligned: IntArray,
    *,
    n_states: int | None = None,
) -> FloatArray:
    """Row-normalized ``K x K`` confusion matrix.

    ``confusion[true_label, fit_label]`` is the share of true-``true_label``
    observations that the (aligned) fit labeled as ``fit_label``. A true label
    that never occurs yields an all-zero row rather than NaNs.

    Args:
        true_states: Ground-truth integer state labels (any shape).
        fit_states_aligned: Fitted labels already mapped onto the truth's
            ordering; must have the same shape as ``true_states``.
        n_states: Number of states. Defaults to the module-level ``K``.

    Returns:
        A float64 array of shape ``(n_states, n_states)``.

    Raises:
        ValueError: If the label arrays differ in shape, or if any label lies
            outside ``[0, n_states)``.
    """
    n = K if n_states is None else int(n_states)
    true_arr = np.asarray(true_states)
    fit_arr = np.asarray(fit_states_aligned)
    if true_arr.shape != fit_arr.shape:
        msg = f"shape mismatch: true {true_arr.shape} vs fit {fit_arr.shape}"
        raise ValueError(msg)

    true_flat = true_arr.ravel()
    fit_flat = fit_arr.ravel()
    # np.add.at silently wraps negative labels onto the last row/column and
    # raises an opaque IndexError for labels >= n, so validate up front.
    for labels in (true_flat, fit_flat):
        if labels.size and (labels.min() < 0 or labels.max() >= n):
            msg = (
                f"state labels must lie in [0, {n}); "
                f"got [{labels.min()}, {labels.max()}]"
            )
            raise ValueError(msg)

    counts = np.zeros((n, n), dtype=np.float64)
    np.add.at(counts, (true_flat, fit_flat), 1.0)
    # Clamp totals of empty rows to 1 so those rows stay zero instead of 0/0.
    row_totals = np.maximum(counts.sum(axis=1, keepdims=True), 1.0)
    return np.asarray(counts / row_totals, dtype=np.float64)
def state_recovery_accuracy(true_states: IntArray, fit_states_aligned: IntArray) -> float:
    """Return the fraction of (i, t) entries whose aligned fit label equals the truth.

    Raises:
        ValueError: If the two label arrays do not share a shape.
    """
    shape_true = true_states.shape
    shape_fit = fit_states_aligned.shape
    if shape_true == shape_fit:
        hits = np.equal(true_states, fit_states_aligned)
        return float(hits.mean())
    raise ValueError(f"shape mismatch: true {shape_true} vs fit {shape_fit}")
@dataclass(frozen=True, slots=True)
class ParameterRecoveryReport:
    """Side-by-side recovery metrics, organized by parameter group.

    Each parameter group carries the ground-truth values (``*_true``), the
    aligned fitted values (``*_fit``), and a scalar RMSE between them
    (``*_rmse``). Instances are built by :func:`parameter_recovery`.

    NOTE(review): with the dataclass defaults (``eq=True`` plus
    ``frozen=True``), the generated ``__eq__``/``__hash__`` operate on the
    array fields and may raise for multi-element arrays (assuming
    ``FloatArray`` is a NumPy array alias -- confirm). Compare reports field by
    field instead, e.g. with ``np.allclose``.
    """

    # Transition alpha arrays. ``parameter_recovery`` replaces cells outside
    # ``ALLOWED_TRANSITIONS`` with NaN in both copies.
    transition_alpha_true: FloatArray
    transition_alpha_fit: FloatArray
    # Transition beta arrays, reported without masking.
    transition_beta_true: FloatArray
    transition_beta_fit: FloatArray
    # Transition RMSEs over learnable cells only; the alpha RMSE additionally
    # skips cells whose true alpha is non-finite.
    transition_alpha_rmse: float
    transition_beta_rmse: float
    # Emission parameter arrays, truth vs aligned fit.
    emission_p_true: FloatArray
    emission_p_fit: FloatArray
    emission_mu_true: FloatArray
    emission_mu_fit: FloatArray
    emission_sigma_true: FloatArray
    emission_sigma_fit: FloatArray
    emission_lambda_true: FloatArray
    emission_lambda_fit: FloatArray
    # Emission RMSEs, computed over every entry (no mask).
    emission_p_rmse: float
    emission_mu_rmse: float
    emission_sigma_rmse: float
    emission_lambda_rmse: float
    # Feature names; ``parameter_recovery`` fills this from ``FEATURE_NAMES``.
    feature_names: tuple[str, ...]
def align_fit_to_truth(fit: FitParameters, perm: IntArray) -> FitParameters:
    """Relabel a fitted :class:`FitParameters` to match the truth's state ordering.

    ``perm[fit_label] = true_label``. In the result, position ``true_label``
    of every per-state array holds what ``fit`` stored at the matching
    ``fit_label``. Transition arrays are reindexed along both of their first
    two axes.

    Args:
        fit: Fitted parameters in the fit's own state labeling.
        perm: Permutation of ``0..K-1`` mapping each fit label to a true label.

    Returns:
        A new :class:`FitParameters` in the truth's labeling. Array indexing
        produces copies, so ``fit`` itself is left untouched.

    Raises:
        ValueError: If ``perm`` is not a permutation of ``0..K-1``. ``argsort``
            only inverts a genuine permutation; duplicates or gaps would
            silently produce mislabeled parameters.
    """
    # NOTE(review): these names come from the same module as the file-level
    # imports, so this local import could be hoisted there.
    from iohmm_evac.inference.fit_params import (
        EmissionFitParams,
        InitialFitParams,
        TransitionFitParams,
    )

    perm_arr = np.asarray(perm)
    if perm_arr.shape != (K,) or not np.array_equal(np.sort(perm_arr), np.arange(K)):
        msg = f"perm must be a permutation of 0..{K - 1}; got {perm_arr.tolist()}"
        raise ValueError(msg)

    inv = np.argsort(perm_arr)  # inv[true_label] = fit_label
    emissions = fit.emissions
    return FitParameters(
        initial=InitialFitParams(logits=fit.initial.logits[inv]),
        transitions=TransitionFitParams(
            alpha=fit.transitions.alpha[inv][:, inv],
            beta=fit.transitions.beta[inv][:, inv],
        ),
        emissions=EmissionFitParams(
            p_departure=emissions.p_departure[inv],
            mu_displacement=emissions.mu_displacement[inv],
            sigma_displacement=emissions.sigma_displacement[inv],
            lambda_comm=emissions.lambda_comm[inv],
            sigma_floor=emissions.sigma_floor,
        ),
        feature_names=fit.feature_names,
    )
def _rmse(a: FloatArray, b: FloatArray, mask: np.ndarray | None = None) -> float:
a_use = a[mask] if mask is not None else a
b_use = b[mask] if mask is not None else b
if a_use.size == 0:
return 0.0
diff = a_use - b_use
return float(np.sqrt(np.mean(diff * diff)))
def parameter_recovery(truth: FitParameters, fit_aligned: FitParameters) -> ParameterRecoveryReport:
    """Build a :class:`ParameterRecoveryReport` comparing truth with an aligned fit.

    ``fit_aligned`` should already use the truth's state labeling (see
    :func:`align_fit_to_truth`).

    RMSE coverage:

    * transition alpha -- cells marked learnable by :func:`learnable_indices`
      whose true alpha is finite. Forbidden transition cells (``-inf`` alpha)
      are therefore excluded, since the model never learns them.
    * transition beta -- learnable cells only.
    * emissions -- every entry.

    For display, both alpha arrays in the report (truth and fit) have cells
    outside ``ALLOWED_TRANSITIONS`` replaced by NaN, so forbidden cells do not
    read as exact ``-inf`` matches. Every other array is reported unchanged.
    """
    learnable, _ = learnable_indices()
    true_tr, fit_tr = truth.transitions, fit_aligned.transitions
    true_em, fit_em = truth.emissions, fit_aligned.emissions

    # Alpha: learnable cells with a finite ground-truth value.
    alpha_cells = learnable & np.isfinite(true_tr.alpha)
    rmse_alpha = _rmse(true_tr.alpha, fit_tr.alpha, alpha_cells)

    # Beta: the 2-D learnable mask is broadcast out to beta's full shape.
    beta_cells = np.broadcast_to(learnable[:, :, None], true_tr.beta.shape)
    rmse_beta = _rmse(true_tr.beta, fit_tr.beta, beta_cells)

    rmse_p = _rmse(true_em.p_departure, fit_em.p_departure)
    rmse_mu = _rmse(true_em.mu_displacement, fit_em.mu_displacement)
    rmse_sigma = _rmse(true_em.sigma_displacement, fit_em.sigma_displacement)
    rmse_lambda = _rmse(true_em.lambda_comm, fit_em.lambda_comm)

    # NaN out disallowed cells in both alpha copies (display only).
    shown_alpha_fit = np.where(ALLOWED_TRANSITIONS, fit_tr.alpha, np.nan)
    shown_alpha_true = np.where(ALLOWED_TRANSITIONS, true_tr.alpha, np.nan)

    return ParameterRecoveryReport(
        transition_alpha_true=shown_alpha_true,
        transition_alpha_fit=shown_alpha_fit,
        transition_beta_true=true_tr.beta,
        transition_beta_fit=fit_tr.beta,
        transition_alpha_rmse=rmse_alpha,
        transition_beta_rmse=rmse_beta,
        emission_p_true=true_em.p_departure,
        emission_p_fit=fit_em.p_departure,
        emission_mu_true=true_em.mu_displacement,
        emission_mu_fit=fit_em.mu_displacement,
        emission_sigma_true=true_em.sigma_displacement,
        emission_sigma_fit=fit_em.sigma_displacement,
        emission_lambda_true=true_em.lambda_comm,
        emission_lambda_fit=fit_em.lambda_comm,
        emission_p_rmse=rmse_p,
        emission_mu_rmse=rmse_mu,
        emission_sigma_rmse=rmse_sigma,
        emission_lambda_rmse=rmse_lambda,
        feature_names=tuple(FEATURE_NAMES),
    )