Source code for zvec_db.embedders.sparse.dismax
"""DisMax (Disjunctive Maximum) sparse embedding for multi-field search.
This module implements the DisMax (Disjunctive Maximum) scoring formula,
which takes the maximum score across multiple queries or fields rather than
summing them. This is useful when you want documents that match at least one
query/field well, rather than documents that match all queries/fields moderately.
Classes
-------
DisMaxTransformer
Scikit-learn transformer implementing DisMax scoring.
DisMaxEmbedder
High-level embedder wrapping DisMaxTransformer with zvec-db compatibility.
Example Usage
-------------
::
from zvec_db.embedders import DisMaxEmbedder
embedder = DisMaxEmbedder(
k1=1.2,
b=0.75,
tie_breaker=0.1,
max_features=4096
)
embedder.fit(documents)
vector = embedder.embed("search query")
"""
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline
from ..base import BaseSparseEmbedder, ExtendedList, SparseVector
if TYPE_CHECKING:
from zvec_db.preprocessing.config import NormalizationConfig
# Default feature budget: 2**13 (= 8192) gives good vocabulary coverage while
# staying memory-efficient; it mirrors the base embedder's default.
DEFAULT_MAX_FEATURES = 2**13

# BM25-style hyperparameters used by the DisMax scorer.
DEFAULT_K1 = 1.2  # term-frequency saturation (higher = slower saturation)
DEFAULT_B = 0.75  # document-length normalization (1.0 = full, 0.0 = off)
DEFAULT_TIE_BREAKER = 0.0  # 0.0 = pure maximum, 1.0 = sum of all term scores
[docs]
class DisMaxTransformer(BaseEstimator, TransformerMixin):
    """Custom scikit-learn transformer implementing the DisMax scoring formula.

    DisMax (Disjunctive Maximum) is a scoring function that takes the maximum
    score across multiple queries or fields, rather than summing them. This is
    useful when you want documents that match at least one query/field well,
    rather than documents that match all queries/fields moderately.

    The DisMax score for a document is computed as:

    .. math::

        \\text{DisMax}(d) = \\max_{q \\in Q}(\\text{score}_q(d)) +
        t \\times \\sum_{q \\in Q \\setminus \\{\\text{argmax}\\}} \\text{score}_q(d)

    where:

    - :math:`Q` is the set of scored components. In :meth:`transform` these
      are the terms stored in each row of the count matrix, and
      :math:`\\text{score}_q(d)` is the BM25 weight of term :math:`q`.
    - :math:`t` is the tie breaker parameter (0.0 = pure max, 1.0 = sum)

    When ``tie_breaker=0.0``, only the maximum score is used (pure DisMax).
    When ``tie_breaker=1.0``, all scores are summed (equivalent to standard fusion).
    Intermediate values provide a blend of both approaches. Non-positive
    values are treated as pure DisMax.

    This transformer is typically used in multi-query scenarios where:

    * Different queries target different aspects of relevance.
    * You want to avoid penalizing documents that match only one query well.
    * Summing scores would unfairly advantage documents matching multiple queries.

    Args:
        k1 (float): Term frequency saturation parameter. Controls how quickly
            term frequency saturates. Higher values mean slower saturation.
            Typical range: 1.2 to 2.0. Defaults to 1.2.
        b (float): Length normalization parameter. Controls the influence of
            document length. ``b=1.0`` means full length normalization,
            ``b=0.0`` disables it. Defaults to 0.75.
        tie_breaker (float): Tie breaker parameter. When multiple terms match,
            adds a fraction of non-maximum scores to the maximum.
            0.0 = use only maximum score, 1.0 = sum all scores.
            Typical range: 0.0 to 0.5. Defaults to 0.0 (pure DisMax).

    Attributes:
        idf_ (ndarray): Computed inverse document frequencies for all terms.
        avgdl_ (float): Average document length in the training corpus.
        is_fitted_ (bool): Whether the transformer has been fitted.

    Example:
        >>> from sklearn.feature_extraction.text import CountVectorizer
        >>> from sklearn.pipeline import Pipeline
        >>> pipeline = Pipeline([
        ...     ("count", CountVectorizer()),
        ...     ("dismax", DisMaxTransformer(k1=1.5, tie_breaker=0.1))
        ... ])
        >>> pipeline.fit(documents)
    """

    def __init__(
        self,
        k1: float = DEFAULT_K1,
        b: float = DEFAULT_B,
        tie_breaker: float = DEFAULT_TIE_BREAKER,
    ):
        """Initialize the DisMax transformer.

        Args:
            k1 (float): Term frequency saturation parameter. Defaults to 1.2.
                Typical range: 1.2-2.0. Higher values mean slower saturation.
            b (float): Length normalization parameter. Defaults to 0.75.
                Typical range: 0.5-1.0. b=1.0 means full length normalization.
            tie_breaker (float): Tie breaker parameter. Defaults to 0.0.
                Typical range: 0.0-0.5. 0.0 = pure max, 1.0 = sum.
        """
        self.k1 = k1
        self.b = b
        self.tie_breaker = tie_breaker
        self.idf_: np.ndarray
        self.avgdl_: float
        self.is_fitted_ = False

    def fit(self, X: csr_matrix, y: Any = None) -> "DisMaxTransformer":
        """Compute IDF values and average document length from a count matrix.

        Args:
            X (csr_matrix): Count matrix of shape ``(n_docs, n_terms)``. Any
                input accepted by ``scipy.sparse.csr_matrix`` (sparse matrix,
                sparse array or dense 2-D array) is converted first.
            y: Ignored; present for scikit-learn compatibility.

        Returns:
            self: The fitted transformer.

        Raises:
            ValueError: If the corpus contains no documents, or if the
                average document length is zero (all documents empty).
        """
        X = csr_matrix(X)
        n_samples = X.shape[0]
        # With zero rows, ``mean()`` yields NaN (plus a RuntimeWarning) and
        # ``NaN == 0`` is False, so the check below would silently pass.
        if n_samples == 0:
            raise ValueError(
                "Cannot fit DisMaxTransformer on an empty corpus (0 documents)."
            )
        avgdl = float(np.asarray(X.sum(axis=1)).mean())
        if avgdl == 0:
            raise ValueError(
                "Average document length is zero. "
                "This may indicate an empty corpus or all-empty documents."
            )
        # Document frequency = number of stored entries in each column.
        doc_freq = np.diff(X.tocsc().indptr)
        # Lucene-style BM25 IDF; the ``+ 1.0`` inside the log keeps weights
        # positive whenever a term's document frequency is <= n_samples.
        # State is only assigned after validation so a failed fit leaves the
        # transformer untouched.
        self.idf_ = np.log((n_samples - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0)
        self.avgdl_ = avgdl
        self.is_fitted_ = True
        return self

    def transform(self, X: csr_matrix) -> csr_matrix:
        """Apply DisMax scoring to a count matrix.

        For each document, computes the BM25 score for each stored term and
        takes the maximum across terms, plus ``tie_breaker`` times the sum of
        the remaining term scores when ``tie_breaker > 0``. Documents with no
        stored terms score 0.

        Args:
            X (csr_matrix): Count matrix of shape ``(n_docs, n_terms)``. Any
                input accepted by ``scipy.sparse.csr_matrix`` is converted
                first.

        Returns:
            csr_matrix: DisMax-weighted sparse matrix of shape ``(n_docs, 1)``,
            where each row contains the DisMax score for that document.

        Raises:
            RuntimeError: If the transformer has not been fitted.
            ValueError: If ``X`` has a different number of columns than the
                matrix the transformer was fitted on.
        """
        if not self.is_fitted_:
            raise RuntimeError(
                "DisMaxTransformer must be fitted before transform. Call .fit() first."
            )
        X = csr_matrix(X)
        n_docs, n_terms = X.shape
        n_fitted_terms = self.idf_.shape[0]
        # Without this check a narrower matrix is silently mis-weighted and a
        # wider one fails with an obscure IndexError on ``idf_``.
        if n_terms != n_fitted_terms:
            raise ValueError(
                f"X has {n_terms} features, but DisMaxTransformer was fitted "
                f"with {n_fitted_terms} features."
            )

        doc_len = np.asarray(X.sum(axis=1)).ravel()
        # avgdl_ is validated in fit() to be non-zero.
        norm = self.k1 * (1.0 - self.b + self.b * doc_len / self.avgdl_)

        scores = np.zeros(n_docs)
        nnz_per_row = np.diff(X.indptr)
        nonempty = nnz_per_row > 0
        # Start offset of each non-empty row in X.data; consecutive starts
        # delimit exactly one row's entries because empty rows own no data.
        starts = X.indptr[:-1][nonempty]
        if starts.size:
            tf = X.data
            # Broadcast each document's length normalisation onto its terms.
            denominator = tf + np.repeat(norm, nnz_per_row)
            # Guard against a zero denominator (only possible for
            # non-positive counts).
            term_scores = (
                self.idf_[X.indices]
                * tf
                * (self.k1 + 1.0)
                / np.where(denominator != 0, denominator, 1e-10)
            )
            # DisMax: per-document maximum term score.
            dismax = np.maximum.reduceat(term_scores, starts)
            if self.tie_breaker > 0:
                # Add a fraction of the non-maximum scores (total - max).
                totals = np.add.reduceat(term_scores, starts)
                dismax = dismax + self.tie_breaker * (totals - dismax)
            scores[nonempty] = dismax

        # Return as column vector; zero scores are not stored.
        return csr_matrix(scores.reshape(-1, 1))
[docs]
class DisMaxEmbedder(BaseSparseEmbedder):
    """Sparse embedder that scores each document with DisMax over its terms.

    Instead of summing per-term scores, DisMax keeps the best one and, when
    ``tie_breaker > 0``, adds a fraction of the rest:

    .. math::

        \\text{DisMax}(d) = \\max_{w \\in T} \\text{score}_w(d) +
        \\tau \\sum_{w \\in T \\setminus \\{w^*\\}} \\text{score}_w(d)

    Here :math:`T` is the set of terms stored for document :math:`d`,
    :math:`w^*` is its highest-scoring term, :math:`\\text{score}_w` is the
    BM25 term weight, and :math:`\\tau` is ``tie_breaker``.

    Good fits for this embedder:

    * Multi-field search (title, content, tags) where a strong match in any
      single field should rank highly.
    * Disjunctive queries where matching any query term should be enough.
    * Preventing many weak term matches from inflating a document's score.

    Every embedding is a single-entry mapping ``{0: score}`` (or ``{}`` when
    the document scores zero).

    Args:
        tokenizer (Optional[Callable]): Custom tokenizer applied to each
            document before vectorization.
        is_pretokenized (bool): When True, inputs must already be token
            lists. Cannot be combined with ``tokenizer``.
        max_features (Optional[int]): Maximum number of features to retain.
            Defaults to ``DEFAULT_MAX_FEATURES`` (8192).
        k1 (float): Term-frequency saturation; typically 1.2-2.0, where higher
            means slower saturation. Defaults to 1.2.
        b (float): Length normalization strength; typically 0.5-1.0, where
            1.0 is full normalization. Defaults to 0.75.
        tie_breaker (float): Weight given to non-maximum term scores
            (0.0 = pure maximum, 1.0 = plain sum). Defaults to 0.0.
        preprocessing_config (Optional[NormalizationConfig]): Optional text
            preprocessing (normalization, stemming, stopwords). When set, it
            is applied during fit() and embed().
        **count_params: Extra keyword arguments for ``CountVectorizer``
            (e.g. ``min_df``, ``max_df``, ``ngram_range``).

    Example:
        >>> embedder = DisMaxEmbedder(k1=1.5, tie_breaker=0.1, min_df=2)
        >>> embedder.fit(documents)
        >>> vectors = embedder.embed(["query text"])
    """

    def __init__(
        self,
        tokenizer: Optional[Callable] = None,
        is_pretokenized: bool = False,
        max_features: Optional[int] = DEFAULT_MAX_FEATURES,
        k1: float = DEFAULT_K1,
        b: float = DEFAULT_B,
        tie_breaker: float = DEFAULT_TIE_BREAKER,
        preprocessing_config: Optional["NormalizationConfig"] = None,
        **count_params,
    ):
        """Store DisMax hyperparameters and CountVectorizer options."""
        super().__init__(
            tokenizer,
            is_pretokenized,
            max_features,
            preprocessing_config=preprocessing_config,
        )
        # Scoring hyperparameters, forwarded to DisMaxTransformer in fit().
        self.k1, self.b, self.tie_breaker = k1, b, tie_breaker
        # Raw CountVectorizer kwargs; finalised by the base class in fit().
        self.vectorizer_params = count_params

    def fit(self, corpus: ExtendedList, y: Any = None) -> "DisMaxEmbedder":
        """Build and train the ``CountVectorizer`` -> ``DisMaxTransformer`` pipeline.

        The corpus is first prepared by the base class according to the
        embedder's configuration (custom tokenizer or pre-tokenized mode).

        Args:
            corpus (ExtendedList): Training documents; plain strings unless
                ``is_pretokenized=True`` or a custom ``tokenizer`` is set.
            y: Unused; accepted for scikit-learn compatibility.

        Returns:
            self: The fitted embedder.

        Raises:
            ValueError: If the corpus format does not match the configuration.
        """
        prepared_corpus = self._prepare_corpus(corpus)
        vectorizer_kwargs = self._prepare_vectorizer_params(self.vectorizer_params)

        counter = CountVectorizer(**vectorizer_kwargs)
        scorer = DisMaxTransformer(k1=self.k1, b=self.b, tie_breaker=self.tie_breaker)
        self.model = Pipeline([("count", counter), ("dismax", scorer)])

        self.model.fit(prepared_corpus)
        self.is_fitted_ = True
        return self

    def embed(
        self, input_text: Union[str, List[str], List[List[str]]]
    ) -> Union[SparseVector, List[SparseVector]]:
        """Embed one document or a batch into single-score sparse vectors.

        Args:
            input_text: A single document or a batch of documents.

        Returns:
            Union[SparseVector, List[SparseVector]]: Per document, ``{0: score}``
            holding its DisMax score, or ``{}`` when the score is not stored
            (zero). A single input yields one mapping; a batch yields a list.

        Raises:
            RuntimeError: If the embedder has been neither fitted nor loaded.
        """
        if self.model is None:
            raise RuntimeError("Model must be fitted or loaded before embedding.")

        is_single, processed_input = self.preprocess_input(input_text)
        dismax_scores = self.model.transform(processed_input)

        # Each row of the (n_docs, 1) result holds at most one stored value.
        stored_rows = (
            dismax_scores.getrow(row_idx).data
            for row_idx in range(dismax_scores.shape[0])
        )
        vectors = [
            {0: float(values[0])} if len(values) > 0 else {}
            for values in stored_rows
        ]

        if is_single:
            return vectors[0]
        return vectors