"""OpenAI-compatible API reranker using /rerank and /score endpoints."""
from __future__ import annotations
import logging
from typing import Any, Literal, Optional
import httpx
from ...utils.retry import RetryConfig, retry_with_backoff
from .base import BaseCrossEncoderReranker
logger = logging.getLogger(__name__)
# [docs]  -- Sphinx "view source" link artifact left over from HTML extraction; not code.
class OpenAIReranker(BaseCrossEncoderReranker):
"""Cross-encoder reranker using OpenAI-compatible /rerank or /score endpoints.
Uses vLLM's native endpoints: /rerank for query-document scoring,
/score for text pair similarity. Both return scores in [0, 1].
Args:
query (str): Query for reranking. **Required**.
topn (int): Number of top documents to return. Defaults to 10.
base_url (str): API base URL. Defaults to "http://localhost:8000/v1".
api_key (Optional[str]): API key. Defaults to None.
model (str): Model identifier. Defaults to "BAAI/bge-reranker-v2-m3".
endpoint (Literal["rerank", "score"]): Endpoint to use. Defaults to "rerank".
timeout (float): HTTP timeout in seconds. Defaults to 30.0.
rerank_field (Optional[str]): Document field for scoring. Defaults to None.
fusion_score_weight (float): Weight for cross-encoder vs fusion scores.
1.0 = pure cross-encoder, 0.0 = pure fusion. Defaults to 1.0.
truncate_prompt_tokens (Optional[int]): Max tokens for truncation.
max_retries (int, optional): Maximum number of retry attempts for transient failures.
Set to 0 to disable retries. Defaults to 3.
initial_delay (float, optional): Initial delay before first retry in seconds.
Defaults to 1.0.
max_delay (float, optional): Maximum delay cap in seconds. Defaults to 60.0.
exponential_base (float, optional): Base for exponential backoff. Defaults to 2.0.
jitter (float, optional): Random jitter factor (0.0-1.0) to avoid thundering herd.
Defaults to 0.1.
retry_config (Optional[RetryConfig], optional): Pre-configured retry settings.
If provided, overrides individual retry parameters. Defaults to None.
Example:
>>> from zvec_db.rerankers.cross_encoder import OpenAIReranker
>>> reranker = OpenAIReranker(
... query="machine learning",
... endpoint="rerank",
... base_url="http://localhost:8000",
... )
>>> results = reranker.rerank({"bm25": docs})
>>> # With custom retry settings for production
>>> reranker = OpenAIReranker(
... query="machine learning",
... max_retries=5,
... initial_delay=2.0,
... max_delay=120.0,
... )
Note: Requires vLLM with /rerank or /score endpoint enabled.
"""
# [docs]  -- Sphinx "view source" link artifact left over from HTML extraction; not code.
def __init__(
    self,
    query: str,
    topn: int = 10,
    base_url: str = "http://localhost:8000/v1",
    api_key: Optional[str] = None,
    model: str = "BAAI/bge-reranker-v2-m3",
    endpoint: Literal["rerank", "score"] = "rerank",
    timeout: float = 30.0,
    rerank_field: Optional[str] = None,
    fusion_score_weight: float = 1.0,
    truncate_prompt_tokens: Optional[int] = None,
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: float = 0.1,
    retry_config: Optional[RetryConfig] = None,
):
    """Initialize the reranker and its retry policy.

    See the class docstring for the full parameter reference.
    """
    # Query/topn/field/fusion-weight handling is shared with every
    # cross-encoder reranker, so it lives in the base class.
    super().__init__(
        query=query,
        topn=topn,
        rerank_field=rerank_field,
        fusion_score_weight=fusion_score_weight,
    )
    # Strip a trailing slash so endpoint paths can be appended verbatim.
    self._base_url = base_url.rstrip("/")
    self._api_key = api_key
    self._model = model
    self._endpoint = endpoint
    self._timeout = timeout
    self._truncate_prompt_tokens = truncate_prompt_tokens
    # A pre-built RetryConfig takes precedence over the individual knobs.
    self._retry_config = (
        retry_config
        if retry_config is not None
        else RetryConfig(
            max_retries=max_retries,
            initial_delay=initial_delay,
            max_delay=max_delay,
            exponential_base=exponential_base,
            jitter=jitter,
        )
    )
# Names of this subclass's settings (without the leading underscore of the
# backing attributes).  Presumably consumed by BaseCrossEncoderReranker for
# introspection/serialization of public configuration -- TODO confirm against
# the base class, which is not visible in this file.
_public_names = (
    "base_url",
    "api_key",
    "model",
    "endpoint",
    "timeout",
    "truncate_prompt_tokens",
)
def _call_rerank_api(self, query: str, documents: list[str]) -> list[float]:
    """Call the /rerank endpoint and return one relevance score per document.

    Args:
        query (str): The search query.
        documents (list[str]): List of document texts to score.

    Returns:
        list[float]: Relevance scores, ordered to match ``documents``.

    Raises:
        RuntimeError: If the API call fails after retries.
        httpx.TimeoutException: On timeout (if retries exhausted).
        httpx.ConnectError: On connection failure (if retries exhausted).
        httpx.HTTPStatusError: On HTTP error with retryable status (if retries exhausted).
    """
    headers = {"Content-Type": "application/json"}
    if self._api_key:
        headers["Authorization"] = f"Bearer {self._api_key}"
    payload: dict[str, Any] = {
        "model": self._model,
        "query": query,
        "documents": documents,
    }
    if self._truncate_prompt_tokens is not None:
        payload["truncate_prompt_tokens"] = self._truncate_prompt_tokens

    # BUG FIX: a class-level @retry_with_backoff(max_retries=3, ...) freezes
    # its parameters at import time and silently ignores the retry settings
    # the caller configured in __init__.  Apply the decorator per call from
    # self._retry_config instead (field names match the RetryConfig
    # construction in __init__).
    cfg = self._retry_config

    @retry_with_backoff(
        max_retries=cfg.max_retries,
        initial_delay=cfg.initial_delay,
        max_delay=cfg.max_delay,
        exponential_base=cfg.exponential_base,
        jitter=cfg.jitter,
        retry_on_timeout=True,
    )
    def _post() -> list[float]:
        response = httpx.post(
            f"{self._base_url}/rerank",
            headers=headers,
            json=payload,
            timeout=self._timeout,
        )
        response.raise_for_status()
        data = response.json()
        # Expected shape: {"results": [{"index": 0, "relevance_score": 0.95}, ...]}.
        # vLLM's Cohere-compatible /rerank names the field "relevance_score";
        # "score" is kept as a fallback for servers using the older field name
        # (with plain "score" only, every document would have scored 0.0).
        results = data.get("results", [])
        # Results may arrive ranked by score; re-sort by index so the returned
        # list lines up positionally with the input documents.
        ordered = sorted(results, key=lambda r: r.get("index", 0))
        return [r.get("relevance_score", r.get("score", 0.0)) for r in ordered]

    return _post()
def _call_score_api(self, query: str, documents: list[str]) -> list[float]:
    """Call the /score endpoint and return one similarity score per document.

    Args:
        query (str): The reference text (text_1).
        documents (list[str]): List of candidate texts (text_2).

    Returns:
        list[float]: Similarity scores, ordered to match ``documents``.

    Raises:
        RuntimeError: If the API call fails after retries.
        httpx.TimeoutException: On timeout (if retries exhausted).
        httpx.ConnectError: On connection failure (if retries exhausted).
        httpx.HTTPStatusError: On HTTP error with retryable status (if retries exhausted).
    """
    headers = {"Content-Type": "application/json"}
    if self._api_key:
        headers["Authorization"] = f"Bearer {self._api_key}"
    payload: dict[str, Any] = {
        "model": self._model,
        "text_1": query,
        "text_2": documents,
    }
    if self._truncate_prompt_tokens is not None:
        payload["truncate_prompt_tokens"] = self._truncate_prompt_tokens

    # BUG FIX: the former class-level @retry_with_backoff(...) hard-coded its
    # parameters and ignored self._retry_config; build the retrying wrapper
    # per call from the configured policy instead.
    cfg = self._retry_config

    @retry_with_backoff(
        max_retries=cfg.max_retries,
        initial_delay=cfg.initial_delay,
        max_delay=cfg.max_delay,
        exponential_base=cfg.exponential_base,
        jitter=cfg.jitter,
        retry_on_timeout=True,
    )
    def _post() -> list[float]:
        response = httpx.post(
            f"{self._base_url}/score",
            headers=headers,
            json=payload,
            timeout=self._timeout,
        )
        response.raise_for_status()
        body = response.json()
        # vLLM's /score wraps items under "data" ({"data": [{"score": 0.89},
        # ...]}), while some servers use "results"; check both so the scores
        # aren't silently dropped as an empty list.
        items = body.get("results") or body.get("data") or []
        return [item.get("score", 0.0) for item in items]

    return _post()
def _compute_scores_batch(
    self,
    query: str,
    documents: list[str],
) -> list[float]:
    """Compute relevance scores for a batch of documents.

    Args:
        query (str): The search query.
        documents (list[str]): List of document texts to score.

    Returns:
        list[float]: List of relevance scores.

    Raises:
        ValueError: If ``self._endpoint`` is neither "rerank" nor "score".
    """
    # Nothing to score -- skip the HTTP round-trip entirely.
    if not documents:
        return []
    # Route to the configured endpoint via a dispatch table.
    handlers = {
        "rerank": self._call_rerank_api,
        "score": self._call_score_api,
    }
    handler = handlers.get(self._endpoint)
    if handler is None:
        raise ValueError(f"Unknown endpoint: {self._endpoint}")
    return handler(query, documents)