# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Recovery-diagnostic plots: confusion matrix, parameter scatter, LL trace.
Same conventions as :mod:`iohmm_evac.report.plots`: each function takes an
optional :class:`matplotlib.axes.Axes` (a pair of Axes for
:func:`plot_log_likelihood_trace`), returns what it drew on, and never calls
``plt.show()`` or ``fig.savefig()`` itself.
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
from iohmm_evac.diagnostics.recovery import ParameterRecoveryReport
from iohmm_evac.inference.fit_params import learnable_indices
from iohmm_evac.report.constants import STATE_ORDER
from iohmm_evac.types import FloatArray
if TYPE_CHECKING:
from matplotlib.axes import Axes
# Public plotting entry points of this module; the underscore-prefixed
# scatter / panel helpers defined below are internal implementation details.
__all__ = [
    "plot_log_likelihood_trace",
    "plot_parameter_recovery",
    "plot_state_recovery_confusion",
]
def plot_state_recovery_confusion(confusion: FloatArray, ax: Axes | None = None) -> Axes:
    """Draw a K×K row-normalized confusion matrix as an annotated heatmap.

    Rows are true states and columns are the (aligned) fitted states. Each
    cell is printed with two decimals. The colour scale is pinned to [0, 1] on
    the ``Blues`` colormap, and darker cells mean more mass. Cell text turns
    white above 0.5 so it stays readable on dark cells. State names are used
    as tick labels only when the matrix size equals ``len(STATE_ORDER)``.

    Args:
        confusion: K×K matrix whose ``[i, j]`` entry is the share of
            true-state ``i`` mass that lands in fit-state ``j``.
        ax: Axes to draw on; a new 5×5-inch figure is created when omitted.

    Returns:
        The Axes that was drawn on.
    """
    target = plt.subplots(figsize=(5, 5))[1] if ax is None else ax
    n_states = confusion.shape[0]
    heatmap = target.imshow(confusion, cmap="Blues", vmin=0.0, vmax=1.0, aspect="equal")

    positions = range(n_states)
    target.set_xticks(positions)
    target.set_yticks(positions)
    # Only label ticks with state names when the sizes line up; otherwise the
    # numeric tick positions are left as-is.
    if n_states == len(STATE_ORDER):
        names = list(STATE_ORDER)
        target.set_xticklabels(names)
        target.set_yticklabels(names)
    target.set_xlabel("Fit (aligned)")
    target.set_ylabel("Truth")
    target.set_title("State recovery confusion")

    # Row-major walk: row = true state (y), col = fitted state (x).
    for row, col in np.ndindex(n_states, n_states):
        cell = float(confusion[row, col])
        target.text(
            col,
            row,
            f"{cell:.2f}",
            ha="center",
            va="center",
            color="white" if cell > 0.5 else "black",
            fontsize=8,
        )

    if target.figure is not None:
        target.figure.colorbar(heatmap, ax=target, fraction=0.046, pad=0.04)
    return target
def _scatter_alpha(report: ParameterRecoveryReport, ax: Axes) -> None:
    """Scatter true vs fit transition intercepts (α) for learnable transitions.

    Only ``(from, to)`` cells flagged truthy in the first value returned by
    ``learnable_indices()`` are considered (assumed to be a 2-D mask indexed
    like the α arrays — TODO confirm). Cells where either value is non-finite
    are skipped. Each point is annotated with its ``α[from→to]`` label, built
    from ``STATE_ORDER`` names.

    If no cell survives the filtering, nothing is drawn. This keeps an empty
    "α" entry out of the legend and matches :func:`_scatter_beta`.
    """
    learnable, _ = learnable_indices()
    truth = report.transition_alpha_true
    fit = report.transition_alpha_fit
    xs: list[float] = []
    ys: list[float] = []
    labels: list[str] = []
    for k in range(truth.shape[0]):
        for j in range(truth.shape[1]):
            if not learnable[k, j]:
                continue
            t_val = float(truth[k, j])
            f_val = float(fit[k, j])
            if not (np.isfinite(t_val) and np.isfinite(f_val)):
                continue
            xs.append(t_val)
            ys.append(f_val)
            labels.append(f"α[{STATE_ORDER[k]}→{STATE_ORDER[j]}]")
    if not xs:
        # Nothing plottable: skip the scatter so the legend gets no empty entry.
        return
    ax.scatter(xs, ys, marker="o", color="#1f77b4", s=40, label="α (transition intercept)")
    for x_, y_, lab in zip(xs, ys, labels, strict=True):
        ax.annotate(lab, (x_, y_), fontsize=6, alpha=0.6, xytext=(3, 3), textcoords="offset points")
def _scatter_beta(report: ParameterRecoveryReport, ax: Axes) -> None:
    """Scatter true vs fit transition slopes (β) for every learnable cell.

    The function walks each ``(from, to)`` cell flagged truthy in the first
    value returned by ``learnable_indices()``, and every index along the
    third axis of the β arrays (presumably one per input feature — TODO
    confirm). It keeps only pairs where both values are finite, then draws
    them as a single red-cross series. Nothing is drawn when no pair
    survives.
    """
    mask, _unused = learnable_indices()
    beta_true = report.transition_beta_true
    beta_fit = report.transition_beta_fit
    true_pts: list[float] = []
    fit_pts: list[float] = []
    for src in range(beta_true.shape[0]):
        for dst in range(beta_true.shape[1]):
            if mask[src, dst]:
                for feat in range(beta_true.shape[2]):
                    t_val = float(beta_true[src, dst, feat])
                    f_val = float(beta_fit[src, dst, feat])
                    if np.isfinite(t_val) and np.isfinite(f_val):
                        true_pts.append(t_val)
                        fit_pts.append(f_val)
    if not true_pts:
        return
    ax.scatter(
        true_pts,
        fit_pts,
        marker="x",
        color="#d62728",
        s=20,
        label="β (slope)",
        alpha=0.7,
    )
def _scatter_emissions(report: ParameterRecoveryReport, ax: Axes) -> None:
    """Scatter true vs fit emission parameters, one marker style per family.

    The families are p, μ, σ and λ. Each family's true and fit arrays are
    flattened and paired element-wise, and pairs where either value is
    non-finite are dropped. A family with no remaining pairs is not drawn at
    all, so it gets no empty legend entry. This is the same filtering and
    empty-guard policy used by :func:`_scatter_beta`.

    Raises:
        ValueError: If a family's true and fit arrays have incompatible sizes.
    """
    # (true values, fit values, legend symbol, colour, marker) per family.
    groups = (
        (report.emission_p_true, report.emission_p_fit, "p", "#2ca02c", "s"),
        (report.emission_mu_true, report.emission_mu_fit, "μ", "#9467bd", "^"),
        (report.emission_sigma_true, report.emission_sigma_fit, "σ", "#8c564b", "v"),
        (report.emission_lambda_true, report.emission_lambda_fit, "λ", "#e377c2", "P"),
    )
    for true_vals, fit_vals, label, color, marker in groups:
        t_arr = np.asarray(true_vals, dtype=float).ravel()
        f_arr = np.asarray(fit_vals, dtype=float).ravel()
        keep = np.isfinite(t_arr) & np.isfinite(f_arr)
        if not keep.any():
            continue
        ax.scatter(
            t_arr[keep],
            f_arr[keep],
            marker=marker,
            color=color,
            s=40,
            label=f"{label} (emission)",
        )
def plot_parameter_recovery(report: ParameterRecoveryReport, ax: Axes | None = None) -> Axes:
    """Overlay true-vs-fit scatters for every parameter group on one Axes.

    The function draws transition intercepts (α), transition slopes (β) and
    emission parameters. It then adds a dashed ``y = x`` reference line
    spanning the union of the current x- and y-limits, so points on the line
    indicate exact recovery.

    Args:
        report: Paired true / fitted parameter values to compare.
        ax: Axes to draw on; a new 7×7-inch figure is created when omitted.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        ax = plt.subplots(figsize=(7, 7))[1]
    for draw_group in (_scatter_alpha, _scatter_beta, _scatter_emissions):
        draw_group(report, ax)

    # Read the limits only after all groups are drawn, so the identity line
    # covers whatever range autoscaling settled on.
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    lo, hi = min(x_lo, y_lo), max(x_hi, y_hi)
    ax.plot([lo, hi], [lo, hi], color="black", linestyle="--", alpha=0.4, label="identity")

    ax.set_xlabel("Truth")
    ax.set_ylabel("Fit (aligned)")
    ax.set_title("Parameter recovery")
    ax.legend(loc="upper left", fontsize=7, framealpha=0.9)
    return ax
def _draw_one_ll_panel(
    ax: Axes,
    traces: Sequence[Sequence[float]],
    *,
    best_index: int | None,
    skip_first_iter: bool,
    title: str,
) -> None:
    """Render one log-likelihood-trace panel onto ``ax``.

    Every restart is drawn as a thin grey line, except ``best_index``, which
    is drawn thick and blue and labelled "(best)". X positions are 1-based EM
    iteration numbers.

    With ``skip_first_iter`` the first value of each trace is dropped, so
    plotting starts at iteration 2 and restarts with fewer than two values
    contribute nothing. Empty traces are always skipped.

    The legend is only added when at least one line was drawn. This avoids
    matplotlib's "no artists with labels" warning on an empty panel.
    """
    # Index of the first value to plot; x positions are 1-based iteration numbers.
    start = 1 if skip_first_iter else 0
    drew_any = False
    for i, trace in enumerate(traces):
        values = list(trace[start:])
        if not values:
            continue
        x = np.arange(start + 1, start + 1 + len(values))
        if i == best_index:
            ax.plot(x, values, color="#1f77b4", lw=2.5, label=f"restart {i} (best)")
        else:
            ax.plot(x, values, color="#aaaaaa", lw=1.0, alpha=0.7, label=f"restart {i}")
        drew_any = True
    ax.set_xlabel("EM iteration")
    ax.set_ylabel("Log-likelihood")
    ax.set_title(title)
    if drew_any:
        ax.legend(loc="lower right", fontsize=7, framealpha=0.9)


def plot_log_likelihood_trace(
    traces: Sequence[Sequence[float]],
    ax: Sequence[Axes] | None = None,
    *,
    best_index: int | None = None,
) -> Sequence[Axes]:
    """Two-panel log-likelihood traces over EM iterations.

    Left panel: every iteration, full y-range. This shows the initial-to-final
    jump that random init typically produces. Right panel: iteration 2 onward
    only, auto-scaled. This is the part where convergence behavior is actually
    visible.

    Y-units are *total* log-likelihood across all (i, t) observations. The
    generated figure annotates this so a reader doesn't have to guess at the
    magnitude. Restarts that converged in a single iteration are skipped in
    the right panel.

    If ``ax`` is supplied it must be a length-2 sequence; otherwise a single
    figure with two side-by-side subplots is created.

    Args:
        traces: One sequence of per-iteration log-likelihoods per restart.
        ax: Exactly two Axes (left, right), or ``None`` to create them.
        best_index: Index into ``traces`` of the restart to highlight. When
            omitted, the restart with the largest finite final value is
            chosen. Empty restarts, and restarts whose final value is
            non-finite, are not eligible. Nothing is highlighted if no restart
            qualifies.

    Returns:
        The two Axes, as a list.

    Raises:
        ValueError: If ``ax`` is supplied but does not hold exactly two Axes.
    """
    if ax is None:
        fig, axes_arr = plt.subplots(1, 2, figsize=(12, 4))
        axes: list[Axes] = list(axes_arr)
        fig.suptitle("EM log-likelihood traces (units: total LL across all (i,t))", fontsize=10)
    else:
        axes = list(ax)
        if len(axes) != 2:
            msg = f"plot_log_likelihood_trace needs exactly 2 axes, got {len(axes)}"
            raise ValueError(msg)
    if not traces:
        for a in axes:
            a.set_title("Log-likelihood trace (no restarts)")
        return axes
    if best_index is None:
        # Keep each candidate's position in ``traces``. Taking argmax over a
        # filtered list would return an index into that list, which mislabels
        # the best restart whenever an earlier trace is empty. Non-finite
        # finals are excluded because np.argmax would otherwise pick a NaN.
        candidates = [
            i for i, trace in enumerate(traces) if len(trace) > 0 and np.isfinite(trace[-1])
        ]
        if candidates:
            best_index = candidates[int(np.argmax([traces[i][-1] for i in candidates]))]
    _draw_one_ll_panel(
        axes[0], traces, best_index=best_index, skip_first_iter=False, title="All iterations"
    )
    _draw_one_ll_panel(
        axes[1], traces, best_index=best_index, skip_first_iter=True, title="From iter 2"
    )
    return axes