# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""State-permutation alignment via the Hungarian algorithm.
Latent state indices in a fitted IO-HMM are arbitrary up to permutation.
To compare a fit to the truth (or to another fit) we must pick the best of
the ``K!`` possible relabelings. Rather than enumerating them, we solve the
assignment problem on the ``K x K`` co-occurrence (confusion) matrix.
"""
from __future__ import annotations
import numpy as np
from scipy.optimize import linear_sum_assignment
from iohmm_evac.types import IntArray
__all__ = ["align_states", "apply_permutation"]
def align_states(true_states: IntArray, fit_states: IntArray, k: int) -> IntArray:
    """Return a permutation ``perm`` of length ``K`` aligning fit labels to true labels.

    ``perm[fit_label]`` gives the canonical (true) label that ``fit_label``
    should be relabeled to. The permutation solves a Hungarian assignment
    that maximizes the total mass on the diagonal of the ``K x K``
    confusion matrix between ``true_states`` and ``fit_states``.

    Args:
        true_states: Reference state labels; integers in ``[0, k)``.
        fit_states: Fitted state labels with the same shape as
            ``true_states``; integers in ``[0, k)``.
        k: Number of latent states.

    Returns:
        An ``int64`` array of length ``k`` mapping each fit label to the
        true label it should be renamed to.

    Raises:
        ValueError: If the two label arrays differ in shape, or if any label
            lies outside ``[0, k)``.
    """
    true_arr = np.asarray(true_states)
    fit_arr = np.asarray(fit_states)
    if true_arr.shape != fit_arr.shape:
        msg = f"shape mismatch: true {true_arr.shape} vs fit {fit_arr.shape}"
        raise ValueError(msg)
    # Validate the label range up front. A negative label would otherwise be
    # accepted silently by ``np.add.at`` (Python-style negative indexing) and
    # miscounted in row/column ``k + label``, corrupting the alignment.
    for which, arr in (("true", true_arr), ("fit", fit_arr)):
        if arr.size and (arr.min() < 0 or arr.max() >= k):
            msg = (
                f"{which} labels must lie in [0, {k}); "
                f"got range [{arr.min()}, {arr.max()}]"
            )
            raise ValueError(msg)
    confusion = np.zeros((k, k), dtype=np.int64)
    # Unbuffered scatter-add so repeated (true, fit) pairs all accumulate.
    np.add.at(confusion, (true_arr.ravel(), fit_arr.ravel()), 1)
    # Hungarian minimizes cost; we want to maximize matches, so negate.
    row_ind, col_ind = linear_sum_assignment(-confusion)
    # row_ind holds true labels and col_ind the fit label assigned to each.
    # For a square cost matrix every row and column is assigned exactly once,
    # so this scatter fills every entry of ``perm``.
    perm = np.empty(k, dtype=np.int64)
    perm[col_ind] = row_ind
    return perm
def apply_permutation(states: IntArray, perm: IntArray) -> IntArray:
    """Map every label in ``states`` through the lookup table ``perm``.

    Element-wise this computes ``out[t] = perm[states[t]]``; the relabeled
    result is returned as an ``int64`` array.
    """
    relabeled = perm[states]
    return np.asarray(relabeled, dtype=np.int64)