"""BM25+ sparse embedding with smoothing to prevent zero scores.
This module implements BM25+, an extension of BM25 that adds a smoothing
parameter (delta) to prevent documents with zero term frequency from having
a zero score. This is particularly useful for corpora with many rare terms
or when combining scores from multiple sources.
Classes
-------
BM25PlusTransformer
Scikit-learn transformer implementing BM25+ scoring.
BM25PlusEmbedder
High-level embedder wrapping BM25PlusTransformer with zvec-db compatibility.
Example Usage
-------------
::
from zvec_db.embedders import BM25PlusEmbedder
embedder = BM25PlusEmbedder(
k1=1.2,
b=0.75,
delta=0.5,
max_features=4096
)
embedder.fit(documents)
vector = embedder.embed("search query")
"""
from typing import TYPE_CHECKING, Any, Callable, Optional
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline
from ..base import BaseSparseEmbedder, ExtendedList
from .base import BaseBM25Transformer
if TYPE_CHECKING:
    from zvec_db.preprocessing.config import NormalizationConfig
# Default max_features: 8192 (2^13) provides good vocabulary coverage
# while maintaining memory efficiency. This matches the base class default.
DEFAULT_MAX_FEATURES = 8192
# BM25+ hyperparameters - these are standard values from the literature:
# k1=1.2: Standard value for term frequency saturation (typical range: 1.2-2.0)
# b=0.75: Standard value for length normalization (typical range: 0.5-1.0)
# delta=0.5: Smoothing parameter to prevent zero scores (typical range: 0.4-1.0)
DEFAULT_K1 = 1.2
DEFAULT_B = 0.75
DEFAULT_DELTA = 0.5
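#
# For reference, the standard BM25+ term weight from Lv & Zhai (2011) is
#
#     w(t, d) = IDF(t) * (delta + ((k1 + 1) * tf(t, d))
#                                 / (k1 * (1 - b + b * |d| / avgdl) + tf(t, d)))
#
# where |d| is the document length and avgdl is the average document length
# over the corpus. The added delta keeps the weight strictly positive even
# when tf(t, d) = 0, which is the smoothing behavior described above. This
# is a reference sketch only; the exact form applied here lives in
# BM25PlusTransformer.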
class BM25PlusEmbedder(BaseSparseEmbedder):
    """Sparse embedder implementing the BM25+ scoring formula.

    BM25+ extends BM25 by adding a smoothing parameter (delta) that prevents
    zero scores for terms with zero term frequency. This can improve retrieval
    performance, especially for corpora with many rare terms.

    This class wires together a ``CountVectorizer`` with a ``BM25PlusTransformer``.
    Tokenization behavior is controlled by the two parameters inherited from
    :class:`BaseSparseEmbedder`:

    * ``is_pretokenized`` tells the embedder to expect lists of tokens as
      input and skips preprocessing altogether.
    * ``tokenizer`` lets the client supply a callable that is executed on
      every raw text document *before* vectorization. When a tokenizer is
      used, the data passed to the scikit-learn pipeline consists of token
      lists as well; the vectorizer is therefore configured to act as an
      identity transformer.

    The two options are mutually exclusive and validated by the base class.

    Args:
        tokenizer (Optional[Callable]): Custom tokenizer function. If provided,
            it will be called on each document before vectorization.
        is_pretokenized (bool): If True, input documents must already be lists
            of tokens. Mutually exclusive with ``tokenizer``.
        max_features (Optional[int]): Maximum number of features to retain per
            document. Defaults to 8192.
        k1 (float): Term frequency saturation parameter. Defaults to 1.2.
            Typical range: 1.2-2.0. Higher values mean slower saturation.
        b (float): Length normalization parameter. Defaults to 0.75.
            Typical range: 0.5-1.0. b=1.0 means full length normalization.
        delta (float): Smoothing parameter. Defaults to 0.5.
            Typical range: 0.4-1.0. Higher values increase the baseline score.
        preprocessing_config (Optional[NormalizationConfig]): Configuration for
            automatic text preprocessing (normalization, stemming, stopwords).
            If set, preprocessing is automatically applied during fit() and
            embed().
        **count_params: Additional keyword arguments passed to
            ``CountVectorizer`` (e.g., ``min_df``, ``max_df``, ``ngram_range``).

    Example:
        >>> embedder = BM25PlusEmbedder(k1=1.5, b=0.8, delta=0.6, min_df=2)
        >>> embedder.fit(documents)
        >>> vectors = embedder.embed(["query text"])
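
        With a custom tokenizer (a sketch; ``str.split`` stands in for any
        real tokenizer callable):

        >>> embedder = BM25PlusEmbedder(tokenizer=str.split, delta=0.5)
        >>> embedder.fit(["first document", "second document"])
        >>> vector = embedder.embed(["first query"])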
"""
    def __init__(
        self,
        tokenizer: Optional[Callable] = None,
        is_pretokenized: bool = False,
        max_features: Optional[int] = DEFAULT_MAX_FEATURES,
        k1: float = DEFAULT_K1,
        b: float = DEFAULT_B,
        delta: float = DEFAULT_DELTA,
        preprocessing_config: Optional["NormalizationConfig"] = None,
        **count_params,
    ):
        super().__init__(
            tokenizer,
            is_pretokenized,
            max_features,
            preprocessing_config=preprocessing_config,
        )
        self.k1 = k1
        self.b = b
        self.delta = delta
        self.vectorizer_params = count_params
    def fit(self, corpus: ExtendedList, y: Any = None) -> "BM25PlusEmbedder":
        """Train the BM25+ pipeline on a corpus of documents.

        This method builds a scikit-learn pipeline consisting of:

        1. ``CountVectorizer``: Tokenizes documents and builds term counts.
        2. ``BM25PlusTransformer``: Applies BM25+ weighting to the count matrix.

        The corpus is pre-processed according to the embedder's configuration
        (custom tokenizer or pre-tokenized mode) before being passed to the
        pipeline.

        Args:
            corpus (ExtendedList): Training documents. Must be strings unless
                ``is_pretokenized=True`` or a custom ``tokenizer`` is set.
            y: Ignored; present for scikit-learn compatibility.

        Returns:
            self: The fitted embedder.

        Raises:
            ValueError: If corpus format doesn't match the configuration.
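
        Example:
            A minimal sketch of the fit round trip (``docs`` is any list of
            raw strings):

            >>> docs = ["sparse retrieval with bm25", "dense retrieval"]
            >>> embedder = BM25PlusEmbedder().fit(docs)
            >>> embedder.is_fitted_
            True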
"""
        processed = self._prepare_corpus(corpus)
        params = self._prepare_vectorizer_params(self.vectorizer_params)
        self.model = Pipeline(
            [
                ("count", CountVectorizer(**params)),
                (
                    "bm25plus",
                    BM25PlusTransformer(k1=self.k1, b=self.b, delta=self.delta),
                ),
            ]
        )
        self.model.fit(processed)
        self.is_fitted_ = True
        return self