"""Multi-class classification reranking using HuggingFace transformers."""
from __future__ import annotations
import logging
from typing import Any, Mapping, Optional
from .base import BaseCrossEncoderReranker
logger = logging.getLogger(__name__)
class ClassificationReranker(BaseCrossEncoderReranker):
    """Multi-class classification reranker using HuggingFace transformers.

    This reranker uses a multi-class classification model from HuggingFace
    (via the transformers library) and computes the expected value of the
    class distribution:

    .. math::

        E[\\text{score}] = \\frac{\\sum_{i} prob_i \\times i}{num\\_classes - 1}

    The model outputs logits for each class (0, 1, 2, ..., num_classes-1).
    Softmax is applied to get probabilities, then expected value is computed
    and normalized to [0, 1].

    Args:
        query (str): Query for reranking. **Required**.
        topn (int, optional): Number of top documents to return. Defaults to 10.
        model_name (str, optional): Classification model name from HuggingFace.
            Should be a model fine-tuned for text classification with multiple labels.
            Examples: "cross-encoder/ms-marco-MiniLM-L-6-v2" (binary),
            "nboost/pt-bert-base-uncased-msmarco" (binary),
            or any model with config.num_labels set.
        device (Optional[str], optional): Device to run model on.
            "cpu", "cuda", or None for auto-detect. Defaults to None.
        max_length (Optional[int], optional): Maximum sequence length.
            Defaults to 512.
        num_classes (Optional[int], optional): Number of classes for classification.
            If None, will be inferred from model.config.num_labels.
            For binary: 2 (classes 0 and 1)
            For multi-class: e.g., 5 for 0-4 relevance scale.
            Defaults to None (auto-infer).
        rerank_field (Optional[str], optional): Document field to use for scoring.
            If None, uses the entire document content. Defaults to None.
        batch_size (int, optional): Batch size for inference. Defaults to 32.
        show_progress_bar (bool, optional): Show progress bar during inference.
            Defaults to False.
        fusion_score_weight (float, optional): Weight for blending cross-encoder scores
            with fusion scores.
            Formula: final_score = cross_encoder_score × weight + fusion_score × (1 - weight)

            - weight = 1.0 → 100% cross-encoder, 0% fusion (default)
            - weight = 0.8 → 80% cross-encoder, 20% fusion
            - weight = 0.5 → 50% cross-encoder, 50% fusion
            - weight = 0.0 → 0% cross-encoder, 100% fusion

            Defaults to 1.0 (pure cross-encoder score).
        model_kwargs (Optional[Mapping[str, Any]], optional): Additional keyword arguments
            passed to AutoModelForSequenceClassification and AutoTokenizer. Useful for options like:

            - torch_dtype: Model dtype (torch.float16, torch.bfloat16, "auto" for auto-detection)
            - trust_remote_code: Trust remote code from HuggingFace Hub
            - token: HuggingFace API token for private models
            - revision: Model revision to load
            - cache_dir: Custom cache directory
            - local_files_only: Load only local files
            - attn_implementation: Attention implementation (e.g., "flash_attention_2", "sdpa")
            - load_in_8bit: Enable 8-bit quantization (requires bitsandbytes)
            - load_in_4bit: Enable 4-bit quantization (requires bitsandbytes)
            - device_map: Device mapping for distributed loading (e.g., "auto", "balanced")

            Defaults to None (no additional kwargs).

    Example:
        >>> from zvec_db.rerankers.cross_encoder import ClassificationReranker
        >>>
        >>> # Binary classification (num_classes inferred from model)
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     topn=10,
        ... )
        >>>
        >>> # Multi-level relevance with explicit num_classes
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="your-multi-class-classifier",
        ...     num_classes=5,
        ...     topn=10,
        ... )
        >>>
        >>> reranker.fit([])  # Load model
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for private models or custom options
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="org/private-model",
        ...     model_kwargs={"token": "hf_...", "trust_remote_code": True},
        ... )
        >>> reranker.fit([])
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for dtype (float16 for reduced memory)
        >>> import torch
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     model_kwargs={"torch_dtype": torch.float16},
        ... )
        >>> reranker.fit([])
        >>> results = reranker.rerank({"bm25": docs})
        >>>
        >>> # With model_kwargs for 8-bit quantization (requires bitsandbytes)
        >>> reranker = ClassificationReranker(
        ...     query="machine learning",
        ...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
        ...     model_kwargs={"load_in_8bit": True},
        ... )
        >>> reranker.fit([])
        >>> results = reranker.rerank({"bm25": docs})

    Note:
        - Requires the `transformers` and `torch` packages
        - Model must be trained/fine-tuned for multi-class text classification
        - num_classes is inferred from model.config.num_labels if not provided
        - GPU acceleration available if CUDA is installed
        - Scores are normalized to [0, 1] via expected value

    See Also:
        OpenAIDecoderReranker: API-based classification with LLM logprobs.
    """
def __init__(
self,
query: str,
topn: int = 10,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
device: Optional[str] = None,
max_length: Optional[int] = 512,
num_classes: Optional[int] = None,
rerank_field: Optional[str] = None,
batch_size: int = 32,
show_progress_bar: bool = False,
fusion_score_weight: float = 1.0,
model_kwargs: Optional[Mapping[str, Any]] = None,
):
super().__init__(
query=query,
topn=topn,
rerank_field=rerank_field,
fusion_score_weight=fusion_score_weight,
)
self._model_name = model_name
self._device = device
self._max_length = max_length
self._batch_size = batch_size
self._show_progress_bar = show_progress_bar
self._num_classes = num_classes # None means auto-infer
self._model_kwargs: Mapping[str, Any] = model_kwargs or {}
self._model: Optional[Any] = None
self._tokenizer: Optional[Any] = None
self._fitted = False
if num_classes is not None and num_classes <= 0:
raise ValueError(f"num_classes must be positive or None, got {num_classes}")
_public_names = (
"model_name",
"device",
"max_length",
"num_classes",
"batch_size",
"show_progress_bar",
"model_kwargs",
)
def _load_model(self) -> None:
"""Load the classification model from HuggingFace.
Raises:
OSError: If the model cannot be loaded from HuggingFace.
ImportError: If the transformers or torch packages are not installed.
ValueError: If the model is not a classification model (no num_labels config).
"""
try:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
except ImportError as e:
raise ImportError(
"ClassificationReranker requires the 'transformers' and 'torch' packages. "
"Install it with: pip install transformers torch"
) from e
try:
# Load tokenizer
self._tokenizer = AutoTokenizer.from_pretrained(
self._model_name, **self._model_kwargs
)
# First, load the model config to get num_labels
from transformers import AutoConfig
config = AutoConfig.from_pretrained(self._model_name, **self._model_kwargs)
# Infer num_classes from model config if not provided
if self._num_classes is None:
if not hasattr(config, "num_labels") or config.num_labels is None:
raise ValueError(
f"Model '{self._model_name}' does not have num_labels configured. "
"Please specify num_classes explicitly or use a classification model."
)
self._num_classes = config.num_labels
# Load model - don't override num_labels, use model's native value
self._model = AutoModelForSequenceClassification.from_pretrained(
self._model_name,
**self._model_kwargs,
)
# Move to device
if self._device:
self._model.to(self._device)
else:
# Auto-detect: use CUDA if available
device = "cuda" if torch.cuda.is_available() else "cpu"
self._model.to(device)
self._model.eval()
except OSError as e:
raise OSError(
f"Failed to load model '{self._model_name}' from HuggingFace. "
"Ensure the model name is correct and you have internet connection."
) from e
def fit(self, documents: list[str]) -> "ClassificationReranker":
"""Initialize the reranker by loading the model.
Args:
documents: List of documents (not used, for API compatibility).
Returns:
self: For method chaining.
"""
if self._model is None:
self._load_model()
self._fitted = True
return self
def _compute_scores_batch(
self,
query: str,
documents: list[str],
) -> list[float]:
"""Compute expected value scores for a batch of documents.
Args:
query (str): The search query.
documents (list[str]): List of document texts to score.
Returns:
list[float]: List of expected value scores normalized to [0, 1].
"""
if not documents:
return []
# Auto-load model if not already loaded (lazy loading)
if self._model is None:
self._load_model()
self._fitted = True
import numpy as np
import torch
# Tokenize all pairs (model guaranteed non-None after _load_model)
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded. Call _load_model() first.")
if self._model is None:
raise RuntimeError("Model not loaded. Call _load_model() first.")
pairs = [[query, text] for text in documents]
inputs = self._tokenizer(
pairs,
padding=True,
truncation=True,
max_length=self._max_length,
return_tensors="pt",
)
# Move inputs to device
inputs = {k: v.to(self._model.device) for k, v in inputs.items()}
# Forward pass
try:
with torch.no_grad():
outputs = self._model(**inputs)
logits = outputs.logits
# Apply softmax to get probabilities
probs = torch.softmax(logits, dim=-1).cpu().numpy()
# Validate _num_classes before use
if self._num_classes is None:
raise RuntimeError(
"num_classes not set. Call _infer_num_classes() or set num_classes explicitly."
)
# probs shape: (n_docs, n_classes)
# Compute expected value: E = sum(prob_i * i)
expected_values = np.sum(probs * np.arange(self._num_classes), axis=1)
# Normalize to [0, 1] by dividing by (num_classes - 1)
# If num_classes = 1: use prob[0] directly as score (cross-encoder style)
if self._num_classes > 1:
expected_values = expected_values / (self._num_classes - 1)
else:
# Single class: prob[0] is the relevance score (like a cross-encoder)
expected_values = probs[:, 0]
return expected_values.tolist()
except torch.cuda.OutOfMemoryError as e:
raise MemoryError(
f"Out of memory during inference. "
f"Try reducing batch_size (current: {self._batch_size})."
) from e
except (RuntimeError, ValueError) as e:
raise RuntimeError(f"Model inference failed: {e}") from e