# Source code for zvec_db.rerankers.cross_encoder.openai

"""OpenAI-compatible API reranker using /rerank and /score endpoints."""

from __future__ import annotations

import logging
from typing import Any, Literal, Optional

import httpx

from ...utils.retry import RetryConfig, retry_with_backoff
from .base import BaseCrossEncoderReranker

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class OpenAIReranker(BaseCrossEncoderReranker):
    """Cross-encoder reranker using OpenAI-compatible /rerank or /score endpoints.

    Uses vLLM's native endpoints: /rerank for query-document scoring,
    /score for text pair similarity. Both return scores in [0, 1].

    Args:
        query (str): Query for reranking. **Required**.
        topn (int): Number of top documents to return. Defaults to 10.
        base_url (str): API base URL. Defaults to "http://localhost:8000/v1".
        api_key (Optional[str]): API key, sent as a Bearer token when set.
            Defaults to None.
        model (str): Model identifier. Defaults to "BAAI/bge-reranker-v2-m3".
        endpoint (Literal["rerank", "score"]): Endpoint to use.
            Defaults to "rerank".
        timeout (float): HTTP timeout in seconds. Defaults to 30.0.
        rerank_field (Optional[str]): Document field for scoring.
            Defaults to None.
        fusion_score_weight (float): Weight for cross-encoder vs fusion
            scores. 1.0 = pure cross-encoder, 0.0 = pure fusion.
            Defaults to 1.0.
        truncate_prompt_tokens (Optional[int]): Max tokens for truncation;
            forwarded to the server only when not None. Defaults to None.
        max_retries (int, optional): Maximum number of retry attempts for
            transient failures. Set to 0 to disable retries. Defaults to 3.
        initial_delay (float, optional): Initial delay before first retry
            in seconds. Defaults to 1.0.
        max_delay (float, optional): Maximum delay cap in seconds.
            Defaults to 60.0.
        exponential_base (float, optional): Base for exponential backoff.
            Defaults to 2.0.
        jitter (float, optional): Random jitter factor (0.0-1.0) to avoid
            thundering herd. Defaults to 0.1.
        retry_config (Optional[RetryConfig], optional): Pre-configured retry
            settings. If provided, overrides individual retry parameters.
            Defaults to None.

    Raises:
        ValueError: If ``endpoint`` is not "rerank" or "score".

    Example:
        >>> from zvec_db.rerankers.cross_encoder import OpenAIReranker
        >>> reranker = OpenAIReranker(
        ...     query="machine learning",
        ...     endpoint="rerank",
        ...     base_url="http://localhost:8000",
        ... )
        >>> results = reranker.rerank({"bm25": docs})
        >>> # With custom retry settings for production
        >>> reranker = OpenAIReranker(
        ...     query="machine learning",
        ...     max_retries=5,
        ...     initial_delay=2.0,
        ...     max_delay=120.0,
        ... )

    Note:
        Requires vLLM with /rerank or /score endpoint enabled. Responses are
        accepted with the result list under either ``"results"`` or
        ``"data"``, and per-item scores under either ``"relevance_score"``
        or ``"score"``.
    """

    _public_names = (
        "base_url",
        "api_key",
        "model",
        "endpoint",
        "timeout",
        "truncate_prompt_tokens",
    )

    # Endpoint names this class knows how to call.
    _VALID_ENDPOINTS = ("rerank", "score")

    def __init__(
        self,
        query: str,
        topn: int = 10,
        base_url: str = "http://localhost:8000/v1",
        api_key: Optional[str] = None,
        model: str = "BAAI/bge-reranker-v2-m3",
        endpoint: Literal["rerank", "score"] = "rerank",
        timeout: float = 30.0,
        rerank_field: Optional[str] = None,
        fusion_score_weight: float = 1.0,
        truncate_prompt_tokens: Optional[int] = None,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: float = 0.1,
        retry_config: Optional[RetryConfig] = None,
    ):
        # Fail fast on a bad endpoint instead of on the first scoring call.
        if endpoint not in self._VALID_ENDPOINTS:
            raise ValueError(f"Unknown endpoint: {endpoint}")

        super().__init__(
            query=query,
            topn=topn,
            rerank_field=rerank_field,
            fusion_score_weight=fusion_score_weight,
        )
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._model = model
        self._endpoint = endpoint
        self._timeout = timeout
        self._truncate_prompt_tokens = truncate_prompt_tokens

        # Use retry_config if provided, otherwise use individual parameters.
        if retry_config is not None:
            self._retry_config = retry_config
        else:
            self._retry_config = RetryConfig(
                max_retries=max_retries,
                initial_delay=initial_delay,
                max_delay=max_delay,
                exponential_base=exponential_base,
                jitter=jitter,
            )

    def _with_retry(self, func: Any) -> Any:
        """Wrap ``func`` with retry/backoff driven by ``self._retry_config``.

        The wrapper is built at call time so the instance's configured retry
        settings are honoured (a class-level decorator would bake in fixed
        values and ignore them).

        NOTE(review): assumes ``RetryConfig`` exposes its constructor
        arguments as same-named attributes; confirm against ``utils.retry``.
        """
        cfg = self._retry_config
        return retry_with_backoff(
            max_retries=cfg.max_retries,
            initial_delay=cfg.initial_delay,
            max_delay=cfg.max_delay,
            exponential_base=cfg.exponential_base,
            jitter=cfg.jitter,
            retry_on_timeout=True,
        )(func)

    def _build_headers(self) -> dict[str, str]:
        """Return JSON request headers, adding Bearer auth when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    def _build_payload(self, **fields: Any) -> dict[str, Any]:
        """Return a request body with the model, ``fields`` and optional truncation."""
        payload: dict[str, Any] = {"model": self._model, **fields}
        if self._truncate_prompt_tokens is not None:
            payload["truncate_prompt_tokens"] = self._truncate_prompt_tokens
        return payload

    def _post_json(self, path: str, payload: dict[str, Any]) -> Any:
        """Send one POST to ``{base_url}/{path}`` and return the decoded JSON.

        This performs a single attempt; retries are applied by the caller.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            httpx.TimeoutException: On timeout.
            httpx.ConnectError: On connection failure.
        """
        response = httpx.post(
            f"{self._base_url}/{path}",
            headers=self._build_headers(),
            json=payload,
            timeout=self._timeout,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _extract_scores(data: Any, num_documents: int) -> list[float]:
        """Convert a /rerank or /score JSON body into scores in input order.

        The result list is read from ``"results"`` or, failing that, from
        ``"data"``. Each item's score is read from ``"relevance_score"`` or
        ``"score"``. Items are ordered by their ``"index"`` field when
        present, so the i-th score belongs to the i-th input document.

        Args:
            data: Decoded JSON response body.
            num_documents (int): Number of documents that were sent.

        Returns:
            list[float]: One score per input document, in input order.

        Raises:
            RuntimeError: If no result list is found, an item has no score,
                or the number of scores differs from ``num_documents``.
        """
        items = None
        if isinstance(data, dict):
            items = data.get("results")
            if items is None:
                items = data.get("data")
        if not isinstance(items, list):
            raise RuntimeError(
                "Unexpected reranker response: expected a 'results' or "
                "'data' list"
            )

        # Stable sort on the server-reported index; items without one keep
        # their position in the response.
        ordered = [
            item
            for _, item in sorted(
                enumerate(items),
                key=lambda pair: pair[1].get("index", pair[0]),
            )
        ]

        scores: list[float] = []
        for item in ordered:
            value = item.get("relevance_score", item.get("score"))
            if value is None:
                raise RuntimeError(
                    f"Reranker result item has no score field "
                    f"(keys: {sorted(item)})"
                )
            scores.append(float(value))

        # A short or long list would silently misalign scores and documents.
        if len(scores) != num_documents:
            raise RuntimeError(
                f"Reranker returned {len(scores)} scores for "
                f"{num_documents} documents"
            )
        return scores

    def _call_rerank_api(self, query: str, documents: list[str]) -> list[float]:
        """Call the /rerank endpoint with the configured retry policy.

        Args:
            query (str): The search query.
            documents (list[str]): List of document texts to score.

        Returns:
            list[float]: One relevance score per document, in input order.

        Raises:
            RuntimeError: If the response cannot be mapped to one score per
                document.
            httpx.TimeoutException: On timeout (if retries exhausted).
            httpx.ConnectError: On connection failure (if retries exhausted).
            httpx.HTTPStatusError: On HTTP error (if retries exhausted).
        """
        payload = self._build_payload(query=query, documents=documents)
        data = self._with_retry(self._post_json)("rerank", payload)
        return self._extract_scores(data, len(documents))

    def _call_score_api(self, query: str, documents: list[str]) -> list[float]:
        """Call the /score endpoint with the configured retry policy.

        Args:
            query (str): The reference text (text_1).
            documents (list[str]): List of candidate texts (text_2).

        Returns:
            list[float]: One similarity score per document, in input order.

        Raises:
            RuntimeError: If the response cannot be mapped to one score per
                document.
            httpx.TimeoutException: On timeout (if retries exhausted).
            httpx.ConnectError: On connection failure (if retries exhausted).
            httpx.HTTPStatusError: On HTTP error (if retries exhausted).
        """
        payload = self._build_payload(text_1=query, text_2=documents)
        data = self._with_retry(self._post_json)("score", payload)
        return self._extract_scores(data, len(documents))

    def _compute_scores_batch(
        self,
        query: str,
        documents: list[str],
    ) -> list[float]:
        """Compute relevance scores for a batch of documents.

        Args:
            query (str): The search query.
            documents (list[str]): List of document texts to score.

        Returns:
            list[float]: List of relevance scores (empty for no documents).

        Raises:
            ValueError: If the configured endpoint is unknown.
        """
        if not documents:
            return []
        if self._endpoint == "rerank":
            return self._call_rerank_api(query, documents)
        if self._endpoint == "score":
            return self._call_score_api(query, documents)
        raise ValueError(f"Unknown endpoint: {self._endpoint}")