# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Log-space numerical primitives used by the forward-backward recursions."""
from __future__ import annotations
import numpy as np
from scipy.special import logsumexp as _scipy_logsumexp
from iohmm_evac.types import FloatArray
# Public API of this module: the names exported by a star-import.
# Entries are kept in ASCII-sorted order (upper-case names sort first).
__all__ = [
"LOG_EPS",
"log_softmax",
"logsumexp",
"safe_log",
]
# A very large negative but finite log-value. Because it is finite, subtracting
# one masked entry from another gives 0.0 rather than the NaN produced by
# ``-inf - (-inf)``.
# NOTE(review): not referenced by the functions in the visible part of this
# module; presumably consumed by callers that build masked log-arrays - confirm.
LOG_EPS: float = -1e30
"""Sentinel finite value standing in for ``-inf`` in masked log-arrays."""
def logsumexp(a: FloatArray, axis: int | None = None) -> FloatArray:
    """Compute ``log(sum(exp(a)))`` in a numerically stable way.

    Delegates the reduction to :func:`scipy.special.logsumexp` and
    normalises the result to a float64 ndarray so callers always get the
    same array type back.

    Args:
        a: Array of log-domain values.
        axis: Axis to reduce over; ``None`` reduces over every element.

    Returns:
        The reduced values as a float64 array (0-d when ``axis`` is ``None``).
    """
    return np.asarray(_scipy_logsumexp(a, axis=axis), dtype=np.float64)
def log_softmax(logits: FloatArray, axis: int = -1) -> FloatArray:
    """Normalise ``logits`` in log-space along ``axis``.

    Computes ``logits - log(sum(exp(logits)))`` over ``axis``, so that
    exponentiating the result gives values summing to one along that axis.

    Args:
        logits: Unnormalised log-scores.
        axis: Axis along which to normalise; defaults to the last axis.

    Returns:
        Float64 array with the same shape as ``logits``.
    """
    # Subtract the per-slice maximum first so the largest exponent is
    # exp(0) = 1 and the sum below cannot overflow.
    peak = np.max(logits, axis=axis, keepdims=True)
    centred = logits - peak
    log_norm = np.log(np.sum(np.exp(centred), axis=axis, keepdims=True))
    return np.asarray(centred - log_norm, dtype=np.float64)
def safe_log(x: FloatArray, floor: float = 1e-300) -> FloatArray:
    """Elementwise log with a positive floor to keep ``-inf`` out of arrays.

    Values below ``floor`` (including zero and negatives) are raised to
    ``floor`` before the logarithm is taken. NaN inputs are not floored
    and come back as NaN.

    Args:
        x: Values to take the logarithm of.
        floor: Smallest value passed to ``log``; must be positive for the
            result to be finite.

    Returns:
        ``log(max(x, floor))`` elementwise, computed in float64.
    """
    # Promote to float64 *before* clipping. With a float32/float16 input the
    # scalar floor would otherwise be cast to the array's dtype, where the
    # default 1e-300 underflows to 0.0 and log(0) = -inf leaks through.
    values = np.asarray(x, dtype=np.float64)
    return np.log(np.clip(values, floor, None))