src/iohmm_evac/inference/log_space.py

# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (C) 2026 SWGY, Inc.
"""Log-space numerical primitives used by the forward-backward recursions."""

from __future__ import annotations

import numpy as np
from scipy.special import logsumexp as _scipy_logsumexp

from iohmm_evac.types import FloatArray

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "LOG_EPS",
    "log_softmax",
    "logsumexp",
    "safe_log",
]


# Finite stand-in for log(0) in masked log-arrays. A finite value (rather
# than ``-inf``) keeps max-shift arithmetic such as ``x - max(x)`` well
# defined: with real ``-inf`` entries an all-masked slice would give
# ``-inf - (-inf) = nan``. In float64, exp(LOG_EPS) underflows to exactly 0.0.
LOG_EPS: float = -1e30
"""Sentinel finite value standing in for ``-inf`` in masked log-arrays."""


def logsumexp(a: FloatArray, axis: int | None = None) -> FloatArray:
    """Return ``log(sum(exp(a)))`` reduced over ``axis``, computed stably.

    Delegates to :func:`scipy.special.logsumexp` and coerces whatever it
    returns (a numpy scalar or an array) into a float64 ``ndarray`` so that
    callers always see one consistent type.

    Args:
        a: Input values in log space.
        axis: Axis to reduce over; ``None`` reduces over every element.

    Returns:
        The log-sum-exp as a float64 array (0-d when fully reduced).
    """
    return np.asarray(_scipy_logsumexp(a, axis=axis), dtype=np.float64)


def log_softmax(logits: FloatArray, axis: int = -1) -> FloatArray:
    """Return ``logits`` normalised so that ``exp`` of each slice sums to 1.

    Computes ``logits - log(sum(exp(logits)))`` along ``axis``, i.e. the
    log of the softmax distribution.

    Args:
        logits: Unnormalised log-weights.
        axis: Axis along which to normalise (default: last axis).

    Returns:
        The log-probabilities, same shape as ``logits``, as float64.
    """
    # Centre each slice on its maximum so the largest exponent is exp(0) = 1;
    # this prevents overflow in the exponential for large logits.
    centred = logits - np.max(logits, axis=axis, keepdims=True)
    log_norm = np.log(np.sum(np.exp(centred), axis=axis, keepdims=True))
    return np.asarray(centred - log_norm, dtype=np.float64)


def safe_log(x: FloatArray, floor: float = 1e-300) -> FloatArray:
    """Elementwise natural log with a positive floor to keep ``-inf`` out of arrays.

    Values below ``floor`` (including zeros and negatives) are raised to
    ``floor`` before the log is taken, so every finite input maps to a
    result no smaller than ``log(floor)``.

    Args:
        x: Input values; converted to float64 before clamping.
        floor: Strictly positive lower clamp applied before the log.

    Returns:
        ``log(max(x, floor))`` elementwise, as float64.

    Raises:
        ValueError: If ``floor`` is not strictly positive (a zero or negative
            floor cannot keep ``-inf``/NaN out of the result).
    """
    if not floor > 0.0:
        raise ValueError(f"floor must be strictly positive, got {floor!r}")
    # Upcast before clipping: in float32/float16 the default floor of 1e-300
    # underflows to 0.0, so clamping in the input's own dtype would still
    # yield log(0) = -inf for zero entries.
    values = np.asarray(x, dtype=np.float64)
    return np.log(np.clip(values, floor, None))