zvec_db.rerankers.cross_encoder.classification
Multi-class classification reranking using HuggingFace transformers.
Classes
ClassificationReranker – Multi-class classification reranker using HuggingFace transformers.
- class zvec_db.rerankers.cross_encoder.classification.ClassificationReranker(query, topn=10, model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', device=None, max_length=512, num_classes=None, rerank_field=None, batch_size=32, show_progress_bar=False, fusion_score_weight=1.0, model_kwargs=None)[source]
Multi-class classification reranker using HuggingFace transformers.
This reranker uses a multi-class classification model from HuggingFace (via the transformers library) and computes the expected value of the class distribution:
\[E[\text{score}] = \frac{\sum_{i} \text{prob}_i \times i}{\text{num\_classes} - 1}\]
The model outputs logits for each class (0, 1, 2, …, num_classes-1). Softmax is applied to obtain probabilities, then the expected value is computed and normalized to [0, 1].
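The expected-value scoring described above can be sketched with NumPy. This is a minimal illustration of the math only, not the library's internal implementation; `expected_value_scores` is a hypothetical helper name:

```python
import numpy as np

def expected_value_scores(logits: np.ndarray) -> np.ndarray:
    """Convert per-class logits (batch, num_classes) to scores in [0, 1]."""
    num_classes = logits.shape[-1]
    # Numerically stable softmax over the class dimension
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = exp / exp.sum(axis=-1, keepdims=True)
    # Expected class index, normalized by the maximum class index
    classes = np.arange(num_classes)
    return (probs * classes).sum(axis=-1) / (num_classes - 1)
```

For a binary model, this reduces to the probability of the positive class; for a 5-class relevance model, a uniform distribution over classes yields a score of 0.5.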
- Parameters:
query (str) – Query for reranking. Required.
topn (int, optional) – Number of top documents to return. Defaults to 10.
model_name (str, optional) – Classification model name from HuggingFace. Should be a model fine-tuned for text classification with multiple labels. Examples: "cross-encoder/ms-marco-MiniLM-L-6-v2" (binary), "nboost/pt-bert-base-uncased-msmarco" (binary), or any model with config.num_labels set.
device (Optional[str], optional) – Device to run model on. “cpu”, “cuda”, or None for auto-detect. Defaults to None.
max_length (Optional[int], optional) – Maximum sequence length. Defaults to 512.
num_classes (Optional[int], optional) – Number of classes for classification. If None, inferred from model.config.num_labels. For binary classification: 2 (classes 0 and 1); for multi-class: e.g., 5 for a 0-4 relevance scale. Defaults to None (auto-infer).
rerank_field (Optional[str], optional) – Document field to use for scoring. If None, uses the entire document content. Defaults to None.
batch_size (int, optional) – Batch size for inference. Defaults to 32.
show_progress_bar (bool, optional) – Show progress bar during inference. Defaults to False.
fusion_score_weight (float, optional) – Weight for blending cross-encoder scores with fusion scores.
Formula: final_score = cross_encoder_score × weight + fusion_score × (1 - weight)
weight = 1.0 → 100% cross-encoder, 0% fusion (default)
weight = 0.8 → 80% cross-encoder, 20% fusion
weight = 0.5 → 50% cross-encoder, 50% fusion
weight = 0.0 → 0% cross-encoder, 100% fusion
Defaults to 1.0 (pure cross-encoder score).
model_kwargs (Optional[Mapping[str, Any]], optional) – Additional keyword arguments passed to AutoModelForSequenceClassification and AutoTokenizer. Useful options include:
- torch_dtype: Model dtype (torch.float16, torch.bfloat16, or "auto" for auto-detection)
- trust_remote_code: Trust remote code from the HuggingFace Hub
- token: HuggingFace API token for private models
- revision: Model revision to load
- cache_dir: Custom cache directory
- local_files_only: Load only local files
- attn_implementation: Attention implementation (e.g., "flash_attention_2", "sdpa")
- load_in_8bit: Enable 8-bit quantization (requires bitsandbytes)
- load_in_4bit: Enable 4-bit quantization (requires bitsandbytes)
- device_map: Device mapping for distributed loading (e.g., "auto", "balanced")
Defaults to None (no additional kwargs).
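The fusion_score_weight blend is a linear interpolation between the two score sources. A minimal sketch (blend_scores is a hypothetical helper, not part of the API):

```python
def blend_scores(cross_encoder_score: float,
                 fusion_score: float,
                 weight: float = 1.0) -> float:
    """Linearly interpolate between cross-encoder and fusion scores.

    weight=1.0 keeps only the cross-encoder score (the default);
    weight=0.0 keeps only the upstream fusion score.
    """
    return cross_encoder_score * weight + fusion_score * (1.0 - weight)
```

A partial weight such as 0.8 lets the upstream retrieval ranking temper the cross-encoder's judgment, which can help when the reranking model is noisy on out-of-domain queries.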
Example
>>> from zvec_db.rerankers.cross_encoder import ClassificationReranker
>>>
>>> # Binary classification (num_classes inferred from model)
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
...     topn=10,
... )
>>>
>>> # Multi-level relevance with explicit num_classes
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="your-multi-class-classifier",
...     num_classes=5,
...     topn=10,
... )
>>>
>>> reranker.fit([])  # Load model
>>> results = reranker.rerank({"bm25": docs})
>>>
>>> # With model_kwargs for private models or custom options
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="org/private-model",
...     model_kwargs={"token": "hf_...", "trust_remote_code": True},
... )
>>> reranker.fit([])
>>> results = reranker.rerank({"bm25": docs})
>>>
>>> # With model_kwargs for dtype (float16 for reduced memory)
>>> import torch
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
...     model_kwargs={"torch_dtype": torch.float16},
... )
>>> reranker.fit([])
>>> results = reranker.rerank({"bm25": docs})
>>>
>>> # With model_kwargs for 8-bit quantization (requires bitsandbytes)
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
...     model_kwargs={"load_in_8bit": True},
... )
>>> reranker.fit([])
>>> results = reranker.rerank({"bm25": docs})
Note
- Requires the transformers and torch packages.
- The model must be trained or fine-tuned for multi-class text classification.
- num_classes is inferred from model.config.num_labels if not provided.
- GPU acceleration is available if CUDA is installed.
- Scores are normalized to [0, 1] via the expected value.
See also
OpenAIDecoderReranker: API-based classification with LLM logprobs.
- __init__(query, topn=10, model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', device=None, max_length=512, num_classes=None, rerank_field=None, batch_size=32, show_progress_bar=False, fusion_score_weight=1.0, model_kwargs=None)[source]
- property batch_size
- property device
- property max_length
- property model_kwargs
- property model_name
- property num_classes
- property show_progress_bar