# Source code for zvec_db.rerankers.cross_encoder.openai_encoder

"""OpenAI-compatible API reranker using /classify endpoint for encoder models."""

from __future__ import annotations

import json
import logging
from typing import Any, Optional

import httpx

from .base import BaseCrossEncoderReranker

logger = logging.getLogger(__name__)


class OpenAIEncoderReranker(BaseCrossEncoderReranker):
    """Cross-encoder reranker using the /classify endpoint for encoder models.

    Uses vLLM's /classify endpoint for encoder models (BERT, RoBERTa).
    Computes expected value score from class probabilities:

        E[score] = sum(prob_i * i) / (num_classes - 1)

    Args:
        query (str): Query for reranking. **Required**.
        topn (int): Number of top documents to return. Defaults to 10.
        base_url (str): API base URL. Defaults to "http://localhost:8000/v1".
        api_key (Optional[str]): API key. Defaults to None.
        model (str): Model identifier. Defaults to "BAAI/bge-reranker-v2-m3".
        num_classes (Optional[int]): Number of classes. Auto-detected if None.
        timeout (float): HTTP timeout in seconds. Defaults to 30.0.
        rerank_field (Optional[str]): Document field for scoring.
        fusion_score_weight (float): Cross-encoder vs fusion weight. Default 1.0.
        separator (str): Query-document separator. Defaults to " ".
        truncate_prompt_tokens (Optional[int]): Max tokens for truncation.

    Example:
        >>> from zvec_db.rerankers.cross_encoder import OpenAIEncoderReranker
        >>> reranker = OpenAIEncoderReranker(
        ...     query="machine learning",
        ...     num_classes=2,
        ...     base_url="http://localhost:8000",
        ... )
        >>> results = reranker.rerank({"bm25": docs})

    Note:
        Requires vLLM with /classify endpoint enabled.
    """
[docs] def __init__( self, query: str, topn: int = 10, base_url: str = "http://localhost:8000/v1", api_key: Optional[str] = None, model: str = "BAAI/bge-reranker-v2-m3", num_classes: Optional[int] = None, timeout: float = 30.0, rerank_field: Optional[str] = None, fusion_score_weight: float = 1.0, separator: str = " ", truncate_prompt_tokens: Optional[int] = None, ): super().__init__( query=query, topn=topn, rerank_field=rerank_field, fusion_score_weight=fusion_score_weight, ) self._base_url = base_url.rstrip("/") self._api_key = api_key self._model = model self._num_classes = num_classes self._timeout = timeout self._separator = separator self._truncate_prompt_tokens = truncate_prompt_tokens
_public_names = ( "base_url", "api_key", "model", "num_classes", "timeout", "separator", "truncate_prompt_tokens", ) def _call_classify_api(self, query: str, documents: list[str]) -> list[float]: """Call the /classify endpoint. Args: query (str): The search query. documents (list[str]): List of document texts to score. Returns: list[float]: List of expected value scores. Raises: RuntimeError: If the API call fails. """ headers = { "Content-Type": "application/json", } if self._api_key: headers["Authorization"] = f"Bearer {self._api_key}" # Build query-document pairs with separator pairs = [f"{query}{self._separator}{doc}" for doc in documents] payload: dict[str, Any] = { "model": self._model, "input": pairs, } if self._truncate_prompt_tokens is not None: payload["truncate_prompt_tokens"] = self._truncate_prompt_tokens try: response = httpx.post( f"{self._base_url}/classify", headers=headers, json=payload, timeout=self._timeout, ) response.raise_for_status() data = response.json() # Extract scores from response # Response format: {"data": [{"index": 0, "probs": [0.1, 0.9], "label": "1"}, ...]} results = data.get("data", []) # Sort by index to ensure correct ordering sorted_results = sorted(results, key=lambda x: x.get("index", 0)) scores: list[float] = [] for r in sorted_results: probs = r.get("probs", []) # Get num_classes from API response if available api_num_classes = r.get("num_classes") num_classes = ( api_num_classes if api_num_classes else (self._num_classes or len(probs)) ) # Compute expected value: E = sum(prob_i * i) expected_value = sum(prob * i for i, prob in enumerate(probs)) # Normalize to [0, 1] by dividing by (num_classes - 1) # If num_classes = 1: use prob[0] directly as score (cross-encoder style) if num_classes > 1: expected_value /= num_classes - 1 else: # Single class: prob[0] is the relevance score (like a cross-encoder) expected_value = probs[0] if probs else 0.0 scores.append(expected_value) return scores except httpx.TimeoutException as e: 
raise TimeoutError( f"Request to /classify endpoint timed out after {self._timeout}s. " "Consider increasing the timeout parameter." ) from e except httpx.ConnectError as e: raise ConnectionError( f"Failed to connect to {self._base_url}/classify. " "Ensure the server is running and accessible." ) from e except httpx.HTTPStatusError as e: raise RuntimeError( f"OpenAI /classify API returned status {e.response.status_code}: {e}" ) from e except (json.JSONDecodeError, KeyError, IndexError) as e: raise RuntimeError(f"Failed to parse /classify API response: {e}") from e except httpx.HTTPError as e: raise RuntimeError(f"OpenAI /classify API call failed: {e}") from e def _compute_scores_batch( self, query: str, documents: list[str], ) -> list[float]: """Compute relevance scores for a batch of documents. Args: query (str): The search query. documents (list[str]): List of document texts to score. Returns: list[float]: List of expected value scores. """ if not documents: return [] return self._call_classify_api(query, documents)