# Source: zvec_db.rerankers.cross_encoder.sentence_transformer
"""Sentence Transformer binary cross-encoder reranker."""
from __future__ import annotations
import logging
from typing import Any, Mapping, Optional
import numpy as np
from .base import BaseCrossEncoderReranker
logger = logging.getLogger(__name__)
class SentenceTransformerReranker(BaseCrossEncoderReranker):
    """Cross-encoder reranker using Sentence Transformers models locally.

    This reranker uses the CrossEncoder class from sentence-transformers
    to compute relevance scores between query and document pairs. Unlike
    API-based cross-encoders, this runs entirely locally on CPU or GPU.

    SentenceTransformer CrossEncoder models output a single score via sigmoid
    for binary relevance (relevant/not relevant).

    Args:
        query (str): Query for reranking. **Required**.
        topn (int, optional): Number of top documents to return. Defaults to 10.
        model_name (str, optional): CrossEncoder model name from HuggingFace.
            Examples:
            - "cross-encoder/ms-marco-MiniLM-L-6-v2" (fast, good quality)
            - "cross-encoder/ms-marco-TinyBERT-L-2-v2" (very fast)
            - "cross-encoder/stsb-distilroberta-base" (semantic similarity)
            Defaults to "cross-encoder/ms-marco-MiniLM-L-6-v2".
        device (Optional[str], optional): Device to run model on.
            "cpu", "cuda", or None for auto-detect. Defaults to None.
        max_length (Optional[int], optional): Maximum sequence length.
            Defaults to 512.
        rerank_field (Optional[str], optional): Document field to use for scoring.
            If None, uses the entire document content. Defaults to None.
        batch_size (int, optional): Batch size for inference. Defaults to 32.
        show_progress_bar (bool, optional): Show progress bar during inference.
            Defaults to False.
        fusion_score_weight (float, optional): Weight for blending cross-encoder
            scores with fusion scores.
            Formula: final_score = cross_encoder_score * weight + fusion_score * (1 - weight)
            - weight = 1.0 -> 100% cross-encoder, 0% fusion (default)
            - weight = 0.8 -> 80% cross-encoder, 20% fusion
            - weight = 0.5 -> 50% cross-encoder, 50% fusion
            - weight = 0.0 -> 0% cross-encoder, 100% fusion
            Defaults to 1.0 (pure cross-encoder score).
        model_kwargs (Optional[Mapping[str, Any]], optional): Additional keyword
            arguments passed to the CrossEncoder constructor. Useful for options
            like:
            - torch_dtype: Model dtype (torch.float16, torch.bfloat16, "auto")
            - trust_remote_code: Trust remote code from HuggingFace Hub
            - token: HuggingFace API token for private models
            - revision: Model revision to load
            - cache_dir: Custom cache directory
            - local_files_only: Load only local files
            - attn_implementation: Attention implementation (e.g., "flash_attention_2")
            Defaults to None (no additional kwargs).

    Example:
        >>> from zvec_db.rerankers.cross_encoder import SentenceTransformerReranker
        >>>
        >>> # Binary relevance reranker
        >>> reranker = SentenceTransformerReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     topn=10,
        ... )
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # Blended scores: 80% cross-encoder + 20% fusion
        >>> reranker = SentenceTransformerReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     topn=10,
        ...     fusion_score_weight=0.8,
        ... )
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for private models
        >>> reranker = SentenceTransformerReranker(
        ...     query="machine learning",
        ...     model_name="org/private-model",
        ...     model_kwargs={"token": "hf_..."},
        ... )
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for dtype (float16 for reduced memory)
        >>> import torch
        >>> reranker = SentenceTransformerReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     model_kwargs={"torch_dtype": torch.float16},
        ... )
        >>> results = reranker.rerank({"bm25": docs})

    Note:
        - Requires the `sentence-transformers` package
        - Models are downloaded automatically on first use
        - GPU acceleration available if CUDA is installed
        - Models output scores in [0, 1] via sigmoid

    See Also:
        OpenAIReranker: API-based cross-encoder with LLM.
    """

    # Names of the public configuration parameters of this reranker.
    # NOTE(review): presumably consumed by BaseCrossEncoderReranker for
    # repr/serialization — confirm against the base class.
    _public_names = (
        "model_name",
        "device",
        "max_length",
        "batch_size",
        "show_progress_bar",
        "model_kwargs",
    )

    def __init__(
        self,
        query: str,
        topn: int = 10,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        device: Optional[str] = None,
        max_length: Optional[int] = 512,
        rerank_field: Optional[str] = None,
        batch_size: int = 32,
        show_progress_bar: bool = False,
        fusion_score_weight: float = 1.0,
        model_kwargs: Optional[Mapping[str, Any]] = None,
    ):
        super().__init__(
            query=query,
            topn=topn,
            rerank_field=rerank_field,
            fusion_score_weight=fusion_score_weight,
        )
        self._model_name = model_name
        self._device = device
        self._max_length = max_length
        self._batch_size = batch_size
        self._show_progress_bar = show_progress_bar
        # Never mutate the caller's mapping; fall back to an empty dict.
        self._model_kwargs: Mapping[str, Any] = model_kwargs or {}
        # Model is loaded lazily (on fit() or first scoring call).
        self._model: Optional[Any] = None
        self._fitted = False

    def _ensure_model(self) -> None:
        """Load the model if it has not been loaded yet and mark as fitted.

        Shared lazy-loading path used by both ``fit`` and
        ``_compute_scores_batch``.
        """
        if self._model is None:
            self._load_model()
        self._fitted = True

    def _load_model(self) -> None:
        """Load the CrossEncoder model.

        Raises:
            ImportError: If the sentence-transformers package is not installed.
        """
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as e:
            raise ImportError(
                "SentenceTransformerReranker requires the 'sentence-transformers' package. "
                "Install it with: pip install sentence-transformers"
            ) from e
        self._model = CrossEncoder(
            self._model_name,
            device=self._device,
            max_length=self._max_length,
            **self._model_kwargs,
        )

    def fit(self, documents: list[str]) -> "SentenceTransformerReranker":
        """Initialize the reranker by loading the model.

        For Sentence Transformers CrossEncoder, this loads the model.
        No training is performed as models are pre-trained.

        Args:
            documents: List of documents (not used, for API compatibility).

        Returns:
            self: For method chaining.
        """
        self._ensure_model()
        return self

    def _compute_scores_batch(
        self,
        query: str,
        documents: list[str],
    ) -> list[float]:
        """Compute relevance scores for a batch of documents.

        Args:
            query (str): The search query.
            documents (list[str]): List of document texts to score.

        Returns:
            list[float]: List of relevance scores, one per document, as
                builtin Python floats.

        Raises:
            RuntimeError: If the model is unexpectedly absent after loading.
        """
        if not documents:
            return []
        # Auto-load model if not already loaded (lazy loading).
        self._ensure_model()
        if self._model is None:  # defensive: should be impossible after loading
            raise RuntimeError(
                "Model not loaded. This should not happen after calling _load_model()."
            )
        # Score each (query, document) pair with the cross-encoder.
        pairs = [[query, text] for text in documents]
        scores = self._model.predict(
            pairs,
            batch_size=self._batch_size,
            show_progress_bar=self._show_progress_bar,
        )
        if isinstance(scores, np.ndarray):
            # ndarray.tolist() yields builtin floats for float dtypes.
            scores = scores.tolist()
        if not isinstance(scores, list):
            # Single-item predictions may come back as a bare scalar.
            scores = [scores]
        # Normalize every element: predict() may return a Python list whose
        # elements are still numpy float32 scalars; the contract is list[float].
        return [float(s) for s in scores]