Source code for zvec_db.embedders.sparse.base

"""Base classes for sparse embedding transformers.

This module provides abstract base classes that factor out common logic
for BM25-family transformers (BM25, BM25L, BM25+, etc.).
"""

import logging
from abc import abstractmethod
from typing import Any

import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin

logger = logging.getLogger(__name__)


class BaseBM25Transformer(BaseEstimator, TransformerMixin):
    """Abstract base class for BM25-family transformers.

    This class factors out the logic shared across BM25 variants:

    - IDF computation from document frequencies
    - Average document length calculation
    - Fit/transform boilerplate with validation

    Subclasses must implement :meth:`_compute_norm` and
    :meth:`_compute_scores` to define their specific scoring formula;
    :meth:`transform` calls both.

    Attributes:
        k1 (float): Term frequency saturation parameter.
        idf_ (ndarray): Computed inverse document frequencies for all terms.
        avgdl_ (float): Average document length in the training corpus.
        is_fitted_ (bool): Whether the transformer has been fitted.

    Notes:
        Subclasses should only override:

        - ``_compute_norm()``: Length-normalization term (required)
        - ``_compute_scores()``: Core scoring formula (required)
        - ``__init__()``: To add variant-specific parameters (optional)
    """

    def __init__(self, k1: float = 1.2):
        """Initialize the base transformer.

        Args:
            k1 (float): Term frequency saturation parameter.
                Typical range: 1.2-2.0. Higher values mean slower saturation.
        """
        self.k1 = k1
        # Declared here for type checkers only; the values are set by fit().
        self.idf_: np.ndarray
        self.avgdl_: float
        self.is_fitted_ = False

    def fit(self, X: csr_matrix, y: Any = None) -> "BaseBM25Transformer":
        """Compute IDF values and average document length from a count matrix.

        Args:
            X (csr_matrix): Sparse count matrix of shape ``(n_docs, n_terms)``.
            y: Ignored; present for scikit-learn compatibility.

        Returns:
            self: The fitted transformer.

        Raises:
            ValueError: If the corpus is empty (no documents, or the average
                document length is zero).
        """
        n_samples, n_features = X.shape

        # Document frequency = number of documents with a non-zero count for
        # each term. Explicitly stored zeros are masked out so they are not
        # counted as occurrences.
        entries = X.tocoo()
        df = np.bincount(entries.col[entries.data != 0], minlength=n_features)
        idf = np.log((n_samples - df + 0.5) / (df + 0.5) + 1.0)

        # The mean over zero rows is NaN, and NaN == 0 is False, so a corpus
        # with no documents must be mapped to 0 explicitly to be rejected.
        avgdl = float(np.asarray(X.sum(axis=1)).mean()) if n_samples > 0 else 0.0
        if avgdl == 0:
            raise ValueError(
                "Average document length is zero. "
                "This may indicate an empty corpus or all-empty documents."
            )

        # Store fitted state only after validation so that a failed fit does
        # not leave the estimator half-initialized.
        self.idf_ = idf
        self.avgdl_ = avgdl
        self.is_fitted_ = True
        return self

    def transform(self, X: csr_matrix) -> csr_matrix:
        """Apply BM25 scoring to a count matrix.

        Args:
            X (csr_matrix): Sparse count matrix of shape ``(n_docs, n_terms)``.

        Returns:
            csr_matrix: BM25-weighted sparse matrix of the same shape.

        Raises:
            RuntimeError: If the transformer has not been fitted.
        """
        if not self.is_fitted_:
            raise RuntimeError(
                f"{self.__class__.__name__} must be fitted before transform. "
                "Call .fit() first."
            )

        X = X.tocsr()
        # np.asarray(...).ravel() handles both the np.matrix returned by
        # sparse matrices and the ndarray returned by sparse arrays.
        len_X = np.asarray(X.sum(axis=1)).ravel()

        # Take rows, columns and values from the same COO view under the same
        # mask so the three arrays stay aligned even when X stores explicit
        # zeros. The COO conversion copies the data, so scoring code cannot
        # modify the caller's matrix in place.
        entries = X.tocoo()
        keep = entries.data != 0
        rows = entries.row[keep]
        cols = entries.col[keep]
        data = entries.data[keep]

        # Compute normalization term (variant-specific)
        norm = self._compute_norm(len_X[rows])

        # Compute BM25 scores using variant-specific formula
        new_data = self._compute_scores(data, norm, cols)

        return csr_matrix((new_data, (rows, cols)), shape=X.shape)

    @abstractmethod
    def _compute_norm(self, doc_lengths: np.ndarray) -> np.ndarray:
        """Compute normalization term for the BM25 variant.

        Args:
            doc_lengths: Document lengths for each non-zero entry.

        Returns:
            Normalization term for each entry.
        """
        # @abstractmethod only blocks instantiation under an ABCMeta
        # metaclass, which this class does not declare; raising here keeps
        # the failure explicit if a subclass forgets to override the hook.
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement _compute_norm()"
        )

    @abstractmethod
    def _compute_scores(
        self, data: np.ndarray, norm: np.ndarray, cols: np.ndarray
    ) -> np.ndarray:
        """Compute final BM25 scores using variant-specific formula.

        Args:
            data: Term frequencies for each non-zero entry.
            norm: Pre-computed normalization term for each entry.
            cols: Column indices (used to index IDF values).

        Returns:
            Final BM25 scores for each entry.
        """
        # See _compute_norm: fail loudly instead of returning None.
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement _compute_scores()"
        )

    def _log_zero_denominator_warning(self, count: int) -> None:
        """Log a warning for zero denominators encountered during transform.

        Args:
            count: Number of zero denominators detected.
        """
        if count > 0:
            logger.warning(
                "%s: %d zero denominators detected in transform. "
                "This may indicate empty documents or extreme parameter values. "
                "Using 1e-10 fallback to prevent division by zero.",
                self.__class__.__name__,
                count,
            )

    def _safe_divide(
        self, numerator: np.ndarray, denominator: np.ndarray
    ) -> np.ndarray:
        """Safely divide arrays, replacing zero denominators with small epsilon.

        Args:
            numerator: Numerator array.
            denominator: Denominator array.

        Returns:
            Result of division with zero denominators replaced by 1e-10.
        """
        zero_mask = denominator == 0
        zero_count = int(np.count_nonzero(zero_mask))
        if zero_count:
            self._log_zero_denominator_warning(zero_count)
        # Reuse the mask instead of recomputing the comparison.
        return numerator / np.where(zero_mask, 1e-10, denominator)