Source code for zvec_db.embedders.sparse.base
"""Base classes for sparse embedding transformers.
This module provides abstract base classes that factor out common logic
for BM25-family transformers (BM25, BM25L, BM25+, etc.).
"""
import logging
from abc import abstractmethod
from typing import Any
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
logger = logging.getLogger(__name__)
[docs]
class BaseBM25Transformer(BaseEstimator, TransformerMixin):
    """Abstract base class for BM25-family transformers.

    This class factors out the logic shared across BM25 variants:

    - IDF computation from document frequencies
    - Average document length calculation
    - Fit/transform boilerplate with validation

    Subclasses must implement :meth:`_compute_norm` and
    :meth:`_compute_scores`; :meth:`transform` calls both, in that order,
    to apply the variant-specific formula.

    Attributes:
        k1 (float): Term frequency saturation parameter.
        idf_ (ndarray): Computed inverse document frequencies for all terms.
            Only set by :meth:`fit`.
        avgdl_ (float): Average document length in the training corpus.
            Only set by :meth:`fit`.
        is_fitted_ (bool): Whether the transformer has been fitted.

    Notes:
        Subclasses should only override:

        - ``_compute_norm()``: Length-normalization term (required)
        - ``_compute_scores()``: Core scoring formula (required)
        - ``__init__()``: To add variant-specific parameters (optional)

        NOTE(review): this class does not declare an ABC metaclass itself.
        Unless one of its bases supplies one, ``@abstractmethod`` is not
        enforced at instantiation time and an un-overridden hook simply
        returns ``None`` -- confirm against the scikit-learn bases.
    """

    def __init__(self, k1: float = 1.2):
        """Initialize the base transformer.

        Args:
            k1 (float): Term frequency saturation parameter.
                Typical range: 1.2-2.0. Higher values mean slower saturation.
        """
        self.k1 = k1
        # Bare annotations: they declare the fitted attributes for readers and
        # type checkers but do not create them; `fit` assigns the values.
        self.idf_: np.ndarray
        self.avgdl_: float
        self.is_fitted_ = False

    def fit(self, X: csr_matrix, y: Any = None) -> "BaseBM25Transformer":
        """Compute IDF values and average document length from a count matrix.

        Args:
            X (csr_matrix): Sparse count matrix of shape ``(n_docs, n_terms)``.
            y: Ignored; present for scikit-learn compatibility.

        Returns:
            self: The fitted transformer.

        Raises:
            ValueError: If the corpus has no documents, or if the average
                document length is zero (all documents are empty).
        """
        n_samples, n_features = X.shape

        # With zero rows the mean below would be NaN, and ``NaN == 0`` is
        # False, so the empty corpus has to be rejected explicitly.
        avgdl = (
            float(np.asarray(X.sum(axis=1)).mean()) if n_samples > 0 else 0.0
        )
        if avgdl == 0:
            raise ValueError(
                "Average document length is zero. "
                "This may indicate an empty corpus or all-empty documents."
            )

        # Document frequency = number of non-zero entries per column.
        # Explicitly stored zeros are skipped so that a term only counts for
        # documents that actually contain it.
        coo = X.tocoo()
        df = np.bincount(coo.col[coo.data != 0], minlength=n_features)
        idf = np.log((n_samples - df + 0.5) / (df + 0.5) + 1.0)

        # State is assigned only after validation succeeded, so a failed
        # refit never leaves a half-updated transformer behind.
        self.idf_ = idf
        self.avgdl_ = avgdl
        self.is_fitted_ = True
        return self

    def transform(self, X: csr_matrix) -> csr_matrix:
        """Apply BM25 scoring to a count matrix.

        Args:
            X (csr_matrix): Sparse count matrix of shape ``(n_docs, n_terms)``.

        Returns:
            csr_matrix: BM25-weighted sparse matrix of the same shape.

        Raises:
            RuntimeError: If the transformer has not been fitted.
        """
        # getattr keeps the documented RuntimeError even when a subclass
        # ``__init__`` forgot to call ``super().__init__()``.
        if not getattr(self, "is_fitted_", False):
            raise RuntimeError(
                f"{self.__class__.__name__} must be fitted before transform. "
                "Call .fit() first."
            )

        X = X.tocsr()
        # np.asarray(...).ravel() handles both the np.matrix returned by
        # sparse matrices and the ndarray returned by sparse arrays.
        doc_lengths = np.asarray(X.sum(axis=1)).ravel()

        # Take row/col/data from one COO view so the three arrays are always
        # aligned.  ``X.nonzero()`` drops explicitly stored zeros while
        # ``X.data`` keeps them, so combining those two can misalign entries.
        coo = X.tocoo()
        keep = coo.data != 0
        rows = coo.row[keep]
        cols = coo.col[keep]
        data = coo.data[keep]

        # Compute normalization term (variant-specific)
        norm = self._compute_norm(doc_lengths[rows])
        # Compute BM25 scores using variant-specific formula
        new_data = self._compute_scores(data, norm, cols)
        return csr_matrix((new_data, (rows, cols)), shape=X.shape)

    @abstractmethod
    def _compute_norm(self, doc_lengths: np.ndarray) -> np.ndarray:
        """Compute normalization term for the BM25 variant.

        Args:
            doc_lengths: Document lengths for each non-zero entry.

        Returns:
            Normalization term for each entry.
        """

    @abstractmethod
    def _compute_scores(
        self, data: np.ndarray, norm: np.ndarray, cols: np.ndarray
    ) -> np.ndarray:
        """Compute final BM25 scores using variant-specific formula.

        Args:
            data: Term frequencies for each non-zero entry.
            norm: Pre-computed normalization term for each entry.
            cols: Column indices (used to index IDF values).

        Returns:
            Final BM25 scores for each entry.
        """

    def _log_zero_denominator_warning(self, count: int) -> None:
        """Log a warning for zero denominators encountered during transform.

        Nothing is logged when ``count`` is not positive.

        Args:
            count: Number of zero denominators detected.
        """
        if count > 0:
            logger.warning(
                "%s: %d zero denominators detected in transform. "
                "This may indicate empty documents or extreme parameter values. "
                "Using 1e-10 fallback to prevent division by zero.",
                self.__class__.__name__,
                count,
            )

    def _safe_divide(
        self, numerator: np.ndarray, denominator: np.ndarray
    ) -> np.ndarray:
        """Safely divide arrays, replacing zero denominators with small epsilon.

        A warning is logged (via :meth:`_log_zero_denominator_warning`) when
        at least one denominator is zero.

        Args:
            numerator: Numerator array.
            denominator: Denominator array.

        Returns:
            Result of division with zero denominators replaced by 1e-10.
        """
        zero_mask = denominator == 0
        zero_count = int(np.count_nonzero(zero_mask))
        if zero_count:
            self._log_zero_denominator_warning(zero_count)
        # Reuse the mask instead of recomputing ``denominator != 0``.
        return numerator / np.where(zero_mask, 1e-10, denominator)