"""OpenAI-compatible API reranker using /classify endpoint for encoder models."""
from __future__ import annotations
import json
import logging
from typing import Any, Optional
import httpx
from .base import BaseCrossEncoderReranker
logger = logging.getLogger(__name__)
class OpenAIEncoderReranker(BaseCrossEncoderReranker):
    """Cross-encoder reranker using the /classify endpoint for encoder models.

    Uses vLLM's /classify endpoint for encoder models (BERT, RoBERTa).
    Computes an expected-value score from the returned class probabilities:

        E[score] = sum(prob_i * i) / (num_classes - 1)

    Args:
        query (str): Query for reranking. **Required**.
        topn (int): Number of top documents to return. Defaults to 10.
        base_url (str): API base URL. Defaults to "http://localhost:8000/v1".
        api_key (Optional[str]): API key. Defaults to None.
        model (str): Model identifier. Defaults to "BAAI/bge-reranker-v2-m3".
        num_classes (Optional[int]): Number of classes. Auto-detected if None.
        timeout (float): HTTP timeout in seconds. Defaults to 30.0.
        rerank_field (Optional[str]): Document field for scoring.
        fusion_score_weight (float): Cross-encoder vs fusion weight. Default 1.0.
        separator (str): Query-document separator. Defaults to " ".
        truncate_prompt_tokens (Optional[int]): Max tokens for truncation.

    Example:
        >>> from zvec_db.rerankers.cross_encoder import OpenAIEncoderReranker
        >>> reranker = OpenAIEncoderReranker(
        ...     query="machine learning",
        ...     num_classes=2,
        ...     base_url="http://localhost:8000",
        ... )
        >>> results = reranker.rerank({"bm25": docs})

    Note:
        Requires vLLM with /classify endpoint enabled.
    """

    def __init__(
        self,
        query: str,
        topn: int = 10,
        base_url: str = "http://localhost:8000/v1",
        api_key: Optional[str] = None,
        model: str = "BAAI/bge-reranker-v2-m3",
        num_classes: Optional[int] = None,
        timeout: float = 30.0,
        rerank_field: Optional[str] = None,
        fusion_score_weight: float = 1.0,
        separator: str = " ",
        truncate_prompt_tokens: Optional[int] = None,
    ):
        super().__init__(
            query=query,
            topn=topn,
            rerank_field=rerank_field,
            fusion_score_weight=fusion_score_weight,
        )
        # Strip a trailing slash so f"{base_url}/classify" never doubles it.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._model = model
        self._num_classes = num_classes
        self._timeout = timeout
        self._separator = separator
        self._truncate_prompt_tokens = truncate_prompt_tokens

    # Constructor parameter names exposed as public configuration
    # (presumably consumed by the base class or serialization machinery
    # — defined elsewhere, verify against BaseCrossEncoderReranker).
    _public_names = (
        "base_url",
        "api_key",
        "model",
        "num_classes",
        "timeout",
        "separator",
        "truncate_prompt_tokens",
    )

    def _call_classify_api(self, query: str, documents: list[str]) -> list[float]:
        """Call the /classify endpoint and convert probabilities to scores.

        Args:
            query (str): The search query.
            documents (list[str]): List of document texts to score.

        Returns:
            list[float]: One expected-value score per document, in input order.

        Raises:
            TimeoutError: If the request exceeds the configured timeout.
            ConnectionError: If the server cannot be reached.
            RuntimeError: If the API returns an error status, an unparseable
                response, or any other HTTP-level failure.
        """
        headers = {
            "Content-Type": "application/json",
        }
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        # Build query-document pairs with separator; the encoder scores the
        # concatenated text as a single sequence.
        pairs = [f"{query}{self._separator}{doc}" for doc in documents]
        payload: dict[str, Any] = {
            "model": self._model,
            "input": pairs,
        }
        if self._truncate_prompt_tokens is not None:
            payload["truncate_prompt_tokens"] = self._truncate_prompt_tokens
        try:
            response = httpx.post(
                f"{self._base_url}/classify",
                headers=headers,
                json=payload,
                timeout=self._timeout,
            )
            response.raise_for_status()
            data = response.json()
            # Extract scores from response
            # Response format: {"data": [{"index": 0, "probs": [0.1, 0.9], "label": "1"}, ...]}
            results = data.get("data", [])
            # Sort by index so scores line up with the input document order
            # even if the server returns results out of order.
            sorted_results = sorted(results, key=lambda x: x.get("index", 0))
            scores: list[float] = []
            for r in sorted_results:
                probs = r.get("probs", [])
                # Prefer the server-reported class count; fall back to the
                # configured value, then to the length of the prob vector.
                api_num_classes = r.get("num_classes")
                num_classes = (
                    api_num_classes
                    if api_num_classes
                    else (self._num_classes or len(probs))
                )
                # Compute expected value: E = sum(prob_i * i)
                expected_value = sum(prob * i for i, prob in enumerate(probs))
                # Normalize to [0, 1] by dividing by (num_classes - 1)
                # If num_classes = 1: use prob[0] directly as score (cross-encoder style)
                if num_classes > 1:
                    expected_value /= num_classes - 1
                else:
                    # Single class: prob[0] is the relevance score (like a cross-encoder)
                    expected_value = probs[0] if probs else 0.0
                scores.append(expected_value)
            return scores
        except httpx.TimeoutException as e:
            raise TimeoutError(
                f"Request to /classify endpoint timed out after {self._timeout}s. "
                "Consider increasing the timeout parameter."
            ) from e
        except httpx.ConnectError as e:
            raise ConnectionError(
                f"Failed to connect to {self._base_url}/classify. "
                "Ensure the server is running and accessible."
            ) from e
        except httpx.HTTPStatusError as e:
            raise RuntimeError(
                f"OpenAI /classify API returned status {e.response.status_code}: {e}"
            ) from e
        except (json.JSONDecodeError, KeyError, IndexError) as e:
            raise RuntimeError(f"Failed to parse /classify API response: {e}") from e
        # Catch-all for remaining httpx transport errors; must come after the
        # more specific HTTPStatusError/TimeoutException/ConnectError handlers.
        except httpx.HTTPError as e:
            raise RuntimeError(f"OpenAI /classify API call failed: {e}") from e

    def _compute_scores_batch(
        self,
        query: str,
        documents: list[str],
    ) -> list[float]:
        """Compute relevance scores for a batch of documents.

        Args:
            query (str): The search query.
            documents (list[str]): List of document texts to score.

        Returns:
            list[float]: List of expected value scores. Empty input yields
            an empty list without issuing an API call.
        """
        if not documents:
            return []
        return self._call_classify_api(query, documents)