# Source code for zvec_db.rerankers.cross_encoder.classification

"""Multi-class classification reranking using HuggingFace transformers."""

from __future__ import annotations

import logging
from typing import Any, Mapping, Optional

from .base import BaseCrossEncoderReranker

logger = logging.getLogger(__name__)


class ClassificationReranker(BaseCrossEncoderReranker):
    """Multi-class classification reranker using HuggingFace transformers.

    This reranker uses a multi-class classification model from HuggingFace
    (via the ``transformers`` library) and scores each (query, document) pair
    by the expected value of the class distribution:

    .. math::

        E[\\text{score}] = \\frac{\\sum_{i} prob_i \\times i}{num\\_classes - 1}

    The model outputs logits for each class (0, 1, ..., num_classes - 1).
    Softmax is applied to get probabilities, then the expected value is
    computed and normalized to [0, 1].

    Args:
        query (str): Query for reranking. **Required**.
        topn (int, optional): Number of top documents to return. Defaults to 10.
        model_name (str, optional): Classification model name from HuggingFace.
            Should be a model fine-tuned for text classification with multiple
            labels. Examples: "cross-encoder/ms-marco-MiniLM-L-6-v2" (binary),
            "nboost/pt-bert-base-uncased-msmarco" (binary), or any model with
            ``config.num_labels`` set.
        device (Optional[str], optional): Device to run the model on: "cpu",
            "cuda", or None for auto-detect. Defaults to None.
        max_length (Optional[int], optional): Maximum sequence length.
            Defaults to 512.
        num_classes (Optional[int], optional): Number of classes for
            classification. If None, inferred from ``model.config.num_labels``.
            For binary: 2 (classes 0 and 1). For multi-class: e.g. 5 for a
            0-4 relevance scale. Defaults to None (auto-infer).
        rerank_field (Optional[str], optional): Document field to use for
            scoring. If None, uses the entire document content. Defaults to None.
        batch_size (int, optional): Batch size for inference. Defaults to 32.
        show_progress_bar (bool, optional): Show progress bar during inference.
            Defaults to False.
        fusion_score_weight (float, optional): Weight for blending
            cross-encoder scores with fusion scores.
            Formula: final_score = cross_encoder_score * weight
            + fusion_score * (1 - weight).

            - weight = 1.0 -> 100% cross-encoder, 0% fusion (default)
            - weight = 0.8 -> 80% cross-encoder, 20% fusion
            - weight = 0.5 -> 50% cross-encoder, 50% fusion
            - weight = 0.0 -> 0% cross-encoder, 100% fusion

            Defaults to 1.0 (pure cross-encoder score).
        model_kwargs (Optional[Mapping[str, Any]], optional): Additional
            keyword arguments passed to AutoModelForSequenceClassification and
            AutoTokenizer. Useful options include:

            - torch_dtype: Model dtype (torch.float16, torch.bfloat16,
              "auto" for auto-detection)
            - trust_remote_code: Trust remote code from HuggingFace Hub
            - token: HuggingFace API token for private models
            - revision: Model revision to load
            - cache_dir: Custom cache directory
            - local_files_only: Load only local files
            - attn_implementation: Attention implementation
              (e.g. "flash_attention_2", "sdpa")
            - load_in_8bit: Enable 8-bit quantization (requires bitsandbytes)
            - load_in_4bit: Enable 4-bit quantization (requires bitsandbytes)
            - device_map: Device mapping for distributed loading
              (e.g. "auto", "balanced")

            Defaults to None (no additional kwargs).

    Example:
        >>> from zvec_db.rerankers.cross_encoder import ClassificationReranker
        >>>
        >>> # Binary classification (num_classes inferred from model)
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     topn=10,
        ... )
        >>>
        >>> # Multi-level relevance with explicit num_classes
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="your-multi-class-classifier",
        ...     num_classes=5,
        ...     topn=10,
        ... )
        >>>
        >>> reranker.fit([])  # Load model
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for private models or custom options
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="org/private-model",
        ...     model_kwargs={"token": "hf_...", "trust_remote_code": True},
        ... )
        >>> reranker.fit([])
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for dtype (float16 for reduced memory)
        >>> import torch
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     model_kwargs={"torch_dtype": torch.float16},
        ... )
        >>> reranker.fit([])
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for 8-bit quantization (requires bitsandbytes)
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     model_kwargs={"load_in_8bit": True},
        ... )
        >>> reranker.fit([])
        >>> results = reranker.rerank({"bm25": docs})

    Note:
        - Requires the ``transformers`` and ``torch`` packages
        - Model must be trained/fine-tuned for multi-class text classification
        - num_classes is inferred from model.config.num_labels if not provided
        - GPU acceleration available if CUDA is installed
        - Scores are normalized to [0, 1] via expected value

    See Also:
        OpenAIDecoderReranker: API-based classification with LLM logprobs.
    """
[docs] def __init__( self, query: str, topn: int = 10, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: Optional[str] = None, max_length: Optional[int] = 512, num_classes: Optional[int] = None, rerank_field: Optional[str] = None, batch_size: int = 32, show_progress_bar: bool = False, fusion_score_weight: float = 1.0, model_kwargs: Optional[Mapping[str, Any]] = None, ): super().__init__( query=query, topn=topn, rerank_field=rerank_field, fusion_score_weight=fusion_score_weight, ) self._model_name = model_name self._device = device self._max_length = max_length self._batch_size = batch_size self._show_progress_bar = show_progress_bar self._num_classes = num_classes # None means auto-infer self._model_kwargs: Mapping[str, Any] = model_kwargs or {} self._model: Optional[Any] = None self._tokenizer: Optional[Any] = None self._fitted = False if num_classes is not None and num_classes <= 0: raise ValueError(f"num_classes must be positive or None, got {num_classes}")
_public_names = ( "model_name", "device", "max_length", "num_classes", "batch_size", "show_progress_bar", "model_kwargs", ) def _load_model(self) -> None: """Load the classification model from HuggingFace. Raises: OSError: If the model cannot be loaded from HuggingFace. ImportError: If the transformers or torch packages are not installed. ValueError: If the model is not a classification model (no num_labels config). """ try: import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer except ImportError as e: raise ImportError( "ClassificationReranker requires the 'transformers' and 'torch' packages. " "Install it with: pip install transformers torch" ) from e try: # Load tokenizer self._tokenizer = AutoTokenizer.from_pretrained( self._model_name, **self._model_kwargs ) # First, load the model config to get num_labels from transformers import AutoConfig config = AutoConfig.from_pretrained(self._model_name, **self._model_kwargs) # Infer num_classes from model config if not provided if self._num_classes is None: if not hasattr(config, "num_labels") or config.num_labels is None: raise ValueError( f"Model '{self._model_name}' does not have num_labels configured. " "Please specify num_classes explicitly or use a classification model." ) self._num_classes = config.num_labels # Load model - don't override num_labels, use model's native value self._model = AutoModelForSequenceClassification.from_pretrained( self._model_name, **self._model_kwargs, ) # Move to device if self._device: self._model.to(self._device) else: # Auto-detect: use CUDA if available device = "cuda" if torch.cuda.is_available() else "cpu" self._model.to(device) self._model.eval() except OSError as e: raise OSError( f"Failed to load model '{self._model_name}' from HuggingFace. " "Ensure the model name is correct and you have internet connection." ) from e
[docs] def fit(self, documents: list[str]) -> "ClassificationReranker": """Initialize the reranker by loading the model. Args: documents: List of documents (not used, for API compatibility). Returns: self: For method chaining. """ if self._model is None: self._load_model() self._fitted = True return self
def _compute_scores_batch( self, query: str, documents: list[str], ) -> list[float]: """Compute expected value scores for a batch of documents. Args: query (str): The search query. documents (list[str]): List of document texts to score. Returns: list[float]: List of expected value scores normalized to [0, 1]. """ if not documents: return [] # Auto-load model if not already loaded (lazy loading) if self._model is None: self._load_model() self._fitted = True import numpy as np import torch # Tokenize all pairs (model guaranteed non-None after _load_model) if self._tokenizer is None: raise RuntimeError("Tokenizer not loaded. Call _load_model() first.") if self._model is None: raise RuntimeError("Model not loaded. Call _load_model() first.") pairs = [[query, text] for text in documents] inputs = self._tokenizer( pairs, padding=True, truncation=True, max_length=self._max_length, return_tensors="pt", ) # Move inputs to device inputs = {k: v.to(self._model.device) for k, v in inputs.items()} # Forward pass try: with torch.no_grad(): outputs = self._model(**inputs) logits = outputs.logits # Apply softmax to get probabilities probs = torch.softmax(logits, dim=-1).cpu().numpy() # Validate _num_classes before use if self._num_classes is None: raise RuntimeError( "num_classes not set. Call _infer_num_classes() or set num_classes explicitly." ) # probs shape: (n_docs, n_classes) # Compute expected value: E = sum(prob_i * i) expected_values = np.sum(probs * np.arange(self._num_classes), axis=1) # Normalize to [0, 1] by dividing by (num_classes - 1) # If num_classes = 1: use prob[0] directly as score (cross-encoder style) if self._num_classes > 1: expected_values = expected_values / (self._num_classes - 1) else: # Single class: prob[0] is the relevance score (like a cross-encoder) expected_values = probs[:, 0] return expected_values.tolist() except torch.cuda.OutOfMemoryError as e: raise MemoryError( f"Out of memory during inference. 
" f"Try reducing batch_size (current: {self._batch_size})." ) from e except (RuntimeError, ValueError) as e: raise RuntimeError(f"Model inference failed: {e}") from e