zvec_db.rerankers.cross_encoder.classification

Multi-class classification reranking using HuggingFace transformers.

Classes

ClassificationReranker(query[, topn, ...])

Multi-class classification reranker using HuggingFace transformers.

class zvec_db.rerankers.cross_encoder.classification.ClassificationReranker(query, topn=10, model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', device=None, max_length=512, num_classes=None, rerank_field=None, batch_size=32, show_progress_bar=False, fusion_score_weight=1.0, model_kwargs=None)[source]

Multi-class classification reranker using HuggingFace transformers.

This reranker uses a multi-class classification model from HuggingFace (via the transformers library) and computes the expected value of the class distribution:

\[E[\text{score}] = \frac{\sum_{i} \text{prob}_i \times i}{\text{num\_classes} - 1}\]

The model outputs logits for each class (0, 1, 2, …, num_classes-1). Softmax is applied to obtain class probabilities, then the expected value of the class index is computed and normalized to [0, 1].
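The scoring step above can be sketched in plain Python (illustrative only; the reranker itself applies the same math to batched torch logits):

```python
import math

def expected_value_score(logits):
    """Map per-class logits to a score in [0, 1] via the normalized
    expected class index, as described in the formula above."""
    num_classes = len(logits)
    # Numerically stabilized softmax over the class logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Expected class index, normalized by the maximum index.
    expected = sum(i * p for i, p in enumerate(probs))
    return expected / (num_classes - 1)

# Binary model: a strongly positive "relevant" logit scores near 1.0.
print(round(expected_value_score([-2.0, 2.0]), 3))  # → 0.982
# Five-class model with uniform logits: expected index 2 of 4 → 0.5.
print(expected_value_score([0.0] * 5))  # → 0.5
```

Note that for a binary model this reduces to the softmax probability of the positive class.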

Parameters:
  • query (str) – Query for reranking. Required.

  • topn (int, optional) – Number of top documents to return. Defaults to 10.

  • model_name (str, optional) – Classification model name from HuggingFace. Should be a model fine-tuned for text classification with multiple labels. Examples: “cross-encoder/ms-marco-MiniLM-L-6-v2” (binary), “nboost/pt-bert-base-uncased-msmarco” (binary), or any model with config.num_labels set.

  • device (Optional[str], optional) – Device to run model on. “cpu”, “cuda”, or None for auto-detect. Defaults to None.

  • max_length (Optional[int], optional) – Maximum sequence length. Defaults to 512.

  • num_classes (Optional[int], optional) – Number of classes for classification. If None, it is inferred from model.config.num_labels. For binary classification use 2 (classes 0 and 1); for multi-class use, e.g., 5 for a 0-4 relevance scale. Defaults to None (auto-infer).

  • rerank_field (Optional[str], optional) – Document field to use for scoring. If None, uses the entire document content. Defaults to None.

  • batch_size (int, optional) – Batch size for inference. Defaults to 32.

  • show_progress_bar (bool, optional) – Show progress bar during inference. Defaults to False.

  • fusion_score_weight (float, optional) –

    Weight for blending cross-encoder scores with fusion scores.

    Formula: final_score = cross_encoder_score × weight + fusion_score × (1 - weight)

    • weight = 1.0 → 100% cross-encoder, 0% fusion (default)

    • weight = 0.8 → 80% cross-encoder, 20% fusion

    • weight = 0.5 → 50% cross-encoder, 50% fusion

    • weight = 0.0 → 0% cross-encoder, 100% fusion

    Defaults to 1.0 (pure cross-encoder score).

  • model_kwargs (Optional[Mapping[str, Any]], optional) –

    Additional keyword arguments passed to AutoModelForSequenceClassification and AutoTokenizer. Useful for options like:

    • torch_dtype: Model dtype (torch.float16, torch.bfloat16, “auto” for auto-detection)

    • trust_remote_code: Trust remote code from HuggingFace Hub

    • token: HuggingFace API token for private models

    • revision: Model revision to load

    • cache_dir: Custom cache directory

    • local_files_only: Load only local files

    • attn_implementation: Attention implementation (e.g., “flash_attention_2”, “sdpa”)

    • load_in_8bit: Enable 8-bit quantization (requires bitsandbytes)

    • load_in_4bit: Enable 4-bit quantization (requires bitsandbytes)

    • device_map: Device mapping for distributed loading (e.g., “auto”, “balanced”)

    Defaults to None (no additional kwargs).
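The fusion_score_weight semantics described above can be illustrated with a tiny helper (a sketch of the documented blending formula, not the library's internal code):

```python
def blend_scores(cross_encoder_score, fusion_score, weight=1.0):
    """Blend a cross-encoder score with an upstream fusion score:
    final = cross_encoder * weight + fusion * (1 - weight)."""
    return cross_encoder_score * weight + fusion_score * (1 - weight)

# weight=1.0 keeps the pure cross-encoder score (the default);
# weight=0.5 averages the two; weight=0.0 keeps only the fusion score.
print(blend_scores(0.9, 0.3, weight=1.0))  # → 0.9
print(blend_scores(0.9, 0.3, weight=0.0))  # → 0.3
```

Because both inputs lie in [0, 1] after normalization, the blended score stays in [0, 1] for any weight in that range.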

Example

>>> from zvec_db.rerankers.cross_encoder import ClassificationReranker
>>>
>>> # Binary classification (num_classes inferred from model)
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
...     topn=10,
... )
>>>
>>> # Multi-level relevance with explicit num_classes
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="your-multi-class-classifier",
...     num_classes=5,
...     topn=10,
... )
>>>
>>> reranker.fit([])  # Load model
>>> results = reranker.rerank({"bm25": docs})
>>>
>>> # With model_kwargs for private models or custom options
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="org/private-model",
...     model_kwargs={"token": "hf_...", "trust_remote_code": True},
... )
>>> reranker.fit([])
>>> results = reranker.rerank({"bm25": docs})
>>>
>>> # With model_kwargs for dtype (float16 for reduced memory)
>>> import torch
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
...     model_kwargs={"torch_dtype": torch.float16},
... )
>>> reranker.fit([])
>>> results = reranker.rerank({"bm25": docs})
>>>
>>> # With model_kwargs for 8-bit quantization (requires bitsandbytes)
>>> reranker = ClassificationReranker(
...     query="machine learning",
...     model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
...     model_kwargs={"load_in_8bit": True},
... )
>>> reranker.fit([])
>>> results = reranker.rerank({"bm25": docs})

Note

  • Requires the transformers and torch packages

  • Model must be trained/fine-tuned for multi-class text classification

  • num_classes is inferred from model.config.num_labels if not provided

  • GPU acceleration available if CUDA is installed

  • Scores are normalized to [0, 1] via expected value

See also

OpenAIDecoderReranker: API-based classification with LLM logprobs.

__init__(query, topn=10, model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', device=None, max_length=512, num_classes=None, rerank_field=None, batch_size=32, show_progress_bar=False, fusion_score_weight=1.0, model_kwargs=None)[source]
Parameters:
  • query (str)

  • topn (int)

  • model_name (str)

  • device (str | None)

  • max_length (int | None)

  • num_classes (int | None)

  • rerank_field (str | None)

  • batch_size (int)

  • show_progress_bar (bool)

  • fusion_score_weight (float)

  • model_kwargs (Mapping[str, Any] | None)

fit(documents)[source]

Initialize the reranker by loading the model.

Parameters:

documents (list[str]) – List of documents (not used, for API compatibility).

Returns:

The reranker instance, for method chaining.

Return type:

self

property batch_size
property device
property max_length
property model_kwargs
property model_name
property num_classes
property show_progress_bar