Source code for zvec_db.rerankers.cross_encoder.openai_decoder

"""OpenAI-compatible API reranker using /chat/completions with logprobs."""

from __future__ import annotations

import json
import logging
import math
from typing import Any, Optional

import httpx

from .base import BaseCrossEncoderReranker

logger = logging.getLogger(__name__)


[docs] class OpenAIDecoderReranker(BaseCrossEncoderReranker): """Cross-encoder reranker using LLM logprobs with structured output. Uses /chat/completions with logprobs and regex-constrained output. Computes expected value score from log probabilities: E[score] = sum(prob_i * i) / (num_classes - 1) Args: query (str): Query for reranking. **Required**. topn (int): Number of top documents to return. Defaults to 10. base_url (str): API base URL. Defaults to "http://localhost:8000/v1". api_key (Optional[str]): API key. Defaults to None. model (str): Model identifier. Defaults to "gpt-4o-mini". num_classes (int): Number of classes. Defaults to 2. timeout (float): HTTP timeout in seconds. Defaults to 30.0. max_batch_size (Optional[int]): Max documents per batch. Default None. rerank_field (Optional[str]): Document field for scoring. fusion_score_weight (float): Cross-encoder vs fusion weight. Default 1.0. concurrency (int): Concurrent API calls. Defaults to 4. Example: >>> from zvec_db.rerankers.cross_encoder import OpenAIDecoderReranker >>> reranker = OpenAIDecoderReranker( ... query="machine learning", ... num_classes=2, ... model="gpt-4o-mini", ... ) >>> results = reranker.rerank({"bm25": docs}) Note: Requires model with logprobs support (--enable-logprobs for vLLM). """ MAX_CLASSES = 10 # Maximum classes for regex constraint
[docs] def __init__( self, query: str, topn: int = 10, base_url: str = "http://localhost:8000/v1", api_key: Optional[str] = None, model: str = "gpt-4o-mini", num_classes: int = 2, timeout: float = 30.0, max_batch_size: Optional[int] = None, rerank_field: Optional[str] = None, fusion_score_weight: float = 1.0, concurrency: int = 4, ): super().__init__( query=query, topn=topn, rerank_field=rerank_field, fusion_score_weight=fusion_score_weight, ) self._base_url = base_url.rstrip("/") self._api_key = api_key self._model = model self._num_classes = num_classes self._timeout = timeout self._max_batch_size = max_batch_size self._concurrency = concurrency if num_classes <= 0: raise ValueError(f"num_classes must be positive, got {num_classes}")
_public_names = ( "base_url", "api_key", "model", "num_classes", "timeout", "max_batch_size", "concurrency", ) def _build_prompt(self, query: str, document: str) -> list[dict[str, str]]: """Build the prompt for classification. Args: query (str): Search query. document (str): Document text. Returns: list[dict[str, str]]: Messages for chat completion API. """ classes_str = ", ".join(str(i) for i in range(self._num_classes)) return [ { "role": "system", "content": ( "You are a relevance classifier. " f"Rate the document's relevance to the query using " f"exactly one digit from: {classes_str}." ), }, { "role": "user", "content": ( f"Query: {query}\n\nDocument: {document}\n\n" f"Relevance score (choose one: {classes_str}):" ), }, ] def _call_api(self, messages: list[dict[str, str]]) -> dict[str, Any]: """Call the OpenAI-compatible chat completion API. Args: messages (list[dict[str, str]]): Messages for the chat API. Returns: dict[str, Any]: API response data. Raises: RuntimeError: If the API call fails. """ headers = { "Content-Type": "application/json", } if self._api_key: headers["Authorization"] = f"Bearer {self._api_key}" # Build regex for structured output if self._num_classes <= self.MAX_CLASSES: regex = f"[0-{self._num_classes - 1}]" else: regex = r"\d" payload: dict[str, Any] = { "model": self._model, "messages": messages, "temperature": 0.0, "max_tokens": 1, "logprobs": True, "top_logprobs": self._num_classes, } # Add structured output constraint payload["extra_body"] = {"structured_outputs": {"regex": regex}} try: response = httpx.post( f"{self._base_url}/chat/completions", headers=headers, json=payload, timeout=self._timeout, ) response.raise_for_status() return response.json() except httpx.TimeoutException as e: raise TimeoutError( f"Request to /chat/completions timed out after {self._timeout}s. " "Consider increasing the timeout parameter." ) from e except httpx.ConnectError as e: raise ConnectionError( f"Failed to connect to {self._base_url}/chat/completions. " "Ensure the server is running and accessible." ) from e except httpx.HTTPStatusError as e: raise RuntimeError( f"API returned status {e.response.status_code}: {e}" ) from e except (json.JSONDecodeError, KeyError, IndexError) as e: raise RuntimeError(f"Failed to parse API response: {e}") from e except httpx.HTTPError as e: raise RuntimeError(f"API call failed: {e}") from e def _compute_score(self, query: str, document: str) -> float: """Compute expected value score from logprobs. The expected value is: E = Σ(prob_i × i) / (num_classes - 1) Args: query (str): Search query. document (str): Document text. Returns: float: Expected value score normalized to [0, 1]. """ messages = self._build_prompt(query, document) response = self._call_api(messages) # Extract log probabilities choice = response["choices"][0] logprobs_data = choice.get("logprobs", {}) content = logprobs_data.get("content", []) if not content: logger.warning("No logprobs returned, returning middle value") return 0.5 if self._num_classes > 1 else 0.0 # Handle vLLM logprobs format: content[0].top_logprobs first_token_entry = content[0] if isinstance(first_token_entry, dict) and "top_logprobs" in first_token_entry: top_logprobs = first_token_entry.get("top_logprobs", []) else: top_logprobs = ( first_token_entry if isinstance(first_token_entry, list) else [] ) if not top_logprobs: return 0.5 if self._num_classes > 1 else 0.0 # Map digit tokens to their logprobs class_logprobs: dict[int, float] = {} for token_data in top_logprobs: token = token_data.get("token", "").strip() logprob = token_data.get("logprob", 0.0) # Check if token is a valid digit for our classes if token.isdigit(): digit = int(token) if 0 <= digit < self._num_classes: class_logprobs[digit] = logprob # If no classes matched, return middle value if not class_logprobs: return 0.5 if self._num_classes > 1 else 0.0 # Convert logprobs to probabilities using softmax exp_logprobs = {k: math.exp(v) for k, v in class_logprobs.items()} total = sum(exp_logprobs.values()) if total == 0: return 0.5 if self._num_classes > 1 else 0.0 probs = {k: v / total for k, v in exp_logprobs.items()} # Compute expected value expected_value = sum(prob * digit for digit, prob in probs.items()) # Normalize to [0, 1] by dividing by (num_classes - 1) if self._num_classes > 1: expected_value /= self._num_classes - 1 return expected_value def _compute_scores_batch( self, query: str, documents: list[str], ) -> list[float]: """Compute relevance scores for a batch of documents. Args: query (str): The search query. documents (list[str]): List of document texts to score. Returns: list[float]: List of expected value scores. """ if not documents: return [] if len(documents) <= 1 or self._max_batch_size is None: return [self._compute_score(query, doc) for doc in documents] import concurrent.futures scores: list[float] = [] with concurrent.futures.ThreadPoolExecutor( max_workers=self._concurrency ) as executor: futures = [ executor.submit(self._compute_score, query, doc) for doc in documents ] for future in concurrent.futures.as_completed(futures): try: score = future.result() scores.append(score) except Exception as e: logger.warning(f"Error computing score: {e}") scores.append(0.5) return scores