# Source code for zvec_db.rerankers.cross_encoder.sentence_transformer

"""Sentence Transformer binary cross-encoder reranker."""

from __future__ import annotations

import logging
from typing import Any, Mapping, Optional

import numpy as np

from .base import BaseCrossEncoderReranker

logger = logging.getLogger(__name__)


class SentenceTransformerReranker(BaseCrossEncoderReranker):
    """Cross-encoder reranker using Sentence Transformers models locally.

    This reranker uses the CrossEncoder class from sentence-transformers to
    compute relevance scores between query and document pairs. Unlike
    API-based cross-encoders, this runs entirely locally on CPU or GPU.

    SentenceTransformer CrossEncoder models output a single score via sigmoid
    for binary relevance (relevant/not relevant).

    Args:
        query (str): Query for reranking. **Required**.
        topn (int, optional): Number of top documents to return.
            Defaults to 10.
        model_name (str, optional): CrossEncoder model name from HuggingFace.
            Examples:
            - "cross-encoder/ms-marco-MiniLM-L-6-v2" (fast, good quality)
            - "cross-encoder/ms-marco-TinyBERT-L-2-v2" (very fast)
            - "cross-encoder/stsb-distilroberta-base" (semantic similarity)
            Defaults to "cross-encoder/ms-marco-MiniLM-L-6-v2".
        device (Optional[str], optional): Device to run model on. "cpu",
            "cuda", or None for auto-detect. Defaults to None.
        max_length (Optional[int], optional): Maximum sequence length.
            Defaults to 512.
        rerank_field (Optional[str], optional): Document field to use for
            scoring. If None, uses the entire document content.
            Defaults to None.
        batch_size (int, optional): Batch size for inference. Defaults to 32.
        show_progress_bar (bool, optional): Show progress bar during
            inference. Defaults to False.
        fusion_score_weight (float, optional): Weight for blending
            cross-encoder scores with fusion scores. Formula:
            final_score = cross_encoder_score * weight + fusion_score * (1 - weight)
            - weight = 1.0 -> 100% cross-encoder, 0% fusion (default)
            - weight = 0.8 -> 80% cross-encoder, 20% fusion
            - weight = 0.5 -> 50% cross-encoder, 50% fusion
            - weight = 0.0 -> 0% cross-encoder, 100% fusion
            Defaults to 1.0 (pure cross-encoder score).
        model_kwargs (Optional[Mapping[str, Any]], optional): Additional
            keyword arguments passed to the CrossEncoder constructor.
            Useful for options like:
            - torch_dtype: Model dtype (torch.float16, torch.bfloat16, "auto")
            - trust_remote_code: Trust remote code from HuggingFace Hub
            - token: HuggingFace API token for private models
            - revision: Model revision to load
            - cache_dir: Custom cache directory
            - local_files_only: Load only local files
            - attn_implementation: Attention implementation
              (e.g., "flash_attention_2")
            Defaults to None (no additional kwargs).

    Example:
        >>> from zvec_db.rerankers.cross_encoder import SentenceTransformerReranker
        >>>
        >>> # Binary relevance reranker
        >>> reranker = SentenceTransformerReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     topn=10,
        ... )
        >>> results = reranker.rerank({"bm25": bm25_docs})
        >>>
        >>> # Blended scores: 80% cross-encoder + 20% fusion
        >>> reranker = SentenceTransformerReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     topn=10,
        ...     fusion_score_weight=0.8,
        ... )
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for private models
        >>> reranker = SentenceTransformerReranker(
        ...     query="machine learning",
        ...     model_name="org/private-model",
        ...     model_kwargs={"token": "hf_..."},
        ... )
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for dtype (float16 for reduced memory)
        >>> import torch
        >>> reranker = SentenceTransformerReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     model_kwargs={"torch_dtype": torch.float16},
        ... )
        >>> results = reranker.rerank({"bm25": docs})

    Note:
        - Requires the `sentence-transformers` package
        - Models are downloaded automatically on first use
        - GPU acceleration available if CUDA is installed
        - Models output scores in [0, 1] via sigmoid

    See Also:
        OpenAIReranker: API-based cross-encoder with LLM.
    """

    def __init__(
        self,
        query: str,
        topn: int = 10,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        device: Optional[str] = None,
        max_length: Optional[int] = 512,
        rerank_field: Optional[str] = None,
        batch_size: int = 32,
        show_progress_bar: bool = False,
        fusion_score_weight: float = 1.0,
        model_kwargs: Optional[Mapping[str, Any]] = None,
    ):
        super().__init__(
            query=query,
            topn=topn,
            rerank_field=rerank_field,
            fusion_score_weight=fusion_score_weight,
        )
        self._model_name = model_name
        self._device = device
        self._max_length = max_length
        self._batch_size = batch_size
        self._show_progress_bar = show_progress_bar
        # Defensive copy semantics: treat None as "no extra kwargs".
        self._model_kwargs: Mapping[str, Any] = model_kwargs or {}
        # Model is loaded lazily (on fit() or first scoring call).
        self._model: Optional[Any] = None
        self._fitted = False

    # Constructor parameters exposed for introspection/serialization by the
    # base class machinery (presumably — defined here to mirror __init__).
    _public_names = (
        "model_name",
        "device",
        "max_length",
        "batch_size",
        "show_progress_bar",
        "model_kwargs",
    )

    def _load_model(self) -> None:
        """Load the CrossEncoder model.

        Raises:
            ImportError: If the sentence-transformers package is not installed.
        """
        # Import inside the method so the package remains an optional
        # dependency for users who never construct this reranker.
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as e:
            raise ImportError(
                "SentenceTransformerReranker requires the 'sentence-transformers' package. "
                "Install it with: pip install sentence-transformers"
            ) from e

        self._model = CrossEncoder(
            self._model_name,
            device=self._device,
            max_length=self._max_length,
            **self._model_kwargs,
        )

    def fit(self, documents: list[str]) -> "SentenceTransformerReranker":
        """Initialize the reranker by loading the model.

        For Sentence Transformers CrossEncoder, this loads the model.
        No training is performed as models are pre-trained.

        Args:
            documents: List of documents (not used, for API compatibility).

        Returns:
            self: For method chaining.
        """
        if self._model is None:
            self._load_model()
        self._fitted = True
        return self

    def _compute_scores_batch(
        self,
        query: str,
        documents: list[str],
    ) -> list[float]:
        """Compute relevance scores for a batch of documents.

        Args:
            query (str): The search query.
            documents (list[str]): List of document texts to score.

        Returns:
            list[float]: List of relevance scores, one per document.
        """
        if not documents:
            return []

        # Auto-load model if not already loaded (lazy loading)
        if self._model is None:
            self._load_model()
            self._fitted = True

        # Model guaranteed non-None after _load_model; this guard exists
        # only to satisfy static analysis and catch impossible states.
        if self._model is None:
            raise RuntimeError(
                "Model not loaded. This should not happen after calling _load_model()."
            )

        # CrossEncoder scores each (query, document) pair independently.
        pairs = [[query, text] for text in documents]
        scores = self._model.predict(
            pairs,
            batch_size=self._batch_size,
            show_progress_bar=self._show_progress_bar,
        )

        # Normalize predict()'s return type: it may be an ndarray, a plain
        # list, or a bare scalar for a single pair (version-dependent).
        if isinstance(scores, np.ndarray):
            scores = scores.tolist()
        if not isinstance(scores, list):
            scores = [scores]
        # Coerce every element to a builtin float — a list return path can
        # carry numpy scalars, which would violate the list[float] contract
        # (and e.g. break JSON serialization downstream).
        return [float(s) for s in scores]