"""BM25 sparse embedding using scikit-learn pipelines.
This module implements the BM25 (Best Matching 25) scoring formula, a
probabilistic ranking function widely used in information retrieval.
BM25 improves upon simple term frequency by accounting for document
length normalization and term saturation.
Classes
-------
BM25Transformer
Scikit-learn transformer implementing BM25 scoring.
BM25Embedder
High-level embedder wrapping BM25Transformer with zvec-db compatibility.
Example Usage
-------------
::

    from zvec_db.embedders import BM25Embedder

    embedder = BM25Embedder(
        k1=1.2,
        b=0.75,
        max_features=4096,
    )
    embedder.fit(documents)
    vector = embedder.embed("search query")
"""
from typing import TYPE_CHECKING, Any, Callable, Optional
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline
from ..base import BaseSparseEmbedder, ExtendedList
from .base import BaseBM25Transformer
if TYPE_CHECKING:
from zvec_db.preprocessing.config import NormalizationConfig
# Default max_features: 8192 (2^13) provides good vocabulary coverage
# while maintaining memory efficiency. This matches the base class default.
DEFAULT_MAX_FEATURES = 8192
# BM25 hyperparameters - these are standard values from the literature:
# k1=1.2: Standard value for term frequency saturation (typical range: 1.2-2.0)
# b=0.75: Standard value for length normalization (typical range: 0.5-1.0)
DEFAULT_K1 = 1.2
DEFAULT_B = 0.75
class BM25Embedder(BaseSparseEmbedder):
"""Sparse embedder implementing the BM25 scoring formula.
This class wires together a ``CountVectorizer`` with a lightweight
``BM25Transformer``. Tokenisation behaviour is controlled by the two
parameters inherited from :class:`BaseSparseEmbedder`:
* ``is_pretokenized`` tells the embedder to expect lists of tokens as input
and avoids any preprocessing altogether.
* ``tokenizer`` allows the client to supply a callable that will be
executed on every raw text document *before* vectorisation. When a
tokenizer is used the data passed to the scikit-learn pipeline consists
of token lists as well; the vectorizer is therefore configured to act as
an identity transformer.
The two options are mutually exclusive and validated by the base class.
Args:
tokenizer (Optional[Callable]): Custom tokenizer function.
is_pretokenized (bool): If True, input documents must be lists of tokens.
max_features (Optional[int]): Maximum number of features to retain.
k1 (float): Term frequency saturation parameter. Defaults to 1.2.
b (float): Length normalization parameter. Defaults to 0.75.
preprocessing_config (Optional[NormalizationConfig]): Configuration for
automatic text preprocessing (normalization, stemming, stopwords).
If set, preprocessing is automatically applied during fit() and embed().
**count_params: Additional parameters for CountVectorizer.
"""
def __init__(
self,
tokenizer: Optional[Callable] = None,
is_pretokenized: bool = False,
max_features: Optional[int] = DEFAULT_MAX_FEATURES,
k1: float = DEFAULT_K1,
b: float = DEFAULT_B,
preprocessing_config: Optional["NormalizationConfig"] = None,
**count_params,
):
super().__init__(
tokenizer,
is_pretokenized,
max_features,
preprocessing_config=preprocessing_config,
)
self.k1 = k1
self.b = b
self.vectorizer_params = count_params
def fit(self, corpus: ExtendedList, y: Any = None) -> "BM25Embedder":
"""Train the BM25 pipeline on a corpus of documents.
This method builds a scikit-learn pipeline consisting of:
1. ``CountVectorizer``: Tokenizes documents and builds term counts.
2. ``BM25Transformer``: Applies BM25 weighting to the count matrix.
The corpus is pre-processed according to the embedder's configuration
(custom tokenizer or pre-tokenized mode) before being passed to the
pipeline.
Args:
corpus (ExtendedList): Training documents. Must be strings unless
``is_pretokenized=True`` or a custom ``tokenizer`` is set.
y: Ignored; present for scikit-learn compatibility.
Returns:
self: The fitted embedder.
Raises:
ValueError: If corpus format doesn't match the configuration.
"""
processed = self._prepare_corpus(corpus)
params = self._prepare_vectorizer_params(self.vectorizer_params)
self.model = Pipeline(
[
("count", CountVectorizer(**params)),
("bm25", BM25Transformer(k1=self.k1, b=self.b)),
]
)
self.model.fit(processed)
self.is_fitted_ = True
return self