"""OpenAI-compatible API reranker using /chat/completions with logprobs."""
from __future__ import annotations
import concurrent.futures
import json
import logging
import math
from typing import Any, Optional
import httpx
from .base import BaseCrossEncoderReranker
logger = logging.getLogger(__name__)
class OpenAIDecoderReranker(BaseCrossEncoderReranker):
"""Cross-encoder reranker using LLM logprobs with structured output.
Uses /chat/completions with logprobs and regex-constrained output.
Computes expected value score from log probabilities:
E[score] = sum(prob_i * i) / (num_classes - 1)
Args:
query (str): Query for reranking. **Required**.
topn (int): Number of top documents to return. Defaults to 10.
base_url (str): API base URL. Defaults to "http://localhost:8000/v1".
api_key (Optional[str]): API key. Defaults to None.
model (str): Model identifier. Defaults to "gpt-4o-mini".
num_classes (int): Number of classes. Defaults to 2.
timeout (float): HTTP timeout in seconds. Defaults to 30.0.
max_batch_size (Optional[int]): Max documents per batch. Default None.
rerank_field (Optional[str]): Document field for scoring.
fusion_score_weight (float): Cross-encoder vs fusion weight. Default 1.0.
concurrency (int): Concurrent API calls. Defaults to 4.
Example:
>>> from zvec_db.rerankers.cross_encoder import OpenAIDecoderReranker
>>> reranker = OpenAIDecoderReranker(
... query="machine learning",
... num_classes=2,
... model="gpt-4o-mini",
... )
>>> results = reranker.rerank({"bm25": docs})
Note: Requires model with logprobs support (--enable-logprobs for vLLM).
"""
MAX_CLASSES = 10 # Maximum classes for regex constraint
def __init__(
self,
query: str,
topn: int = 10,
base_url: str = "http://localhost:8000/v1",
api_key: Optional[str] = None,
model: str = "gpt-4o-mini",
num_classes: int = 2,
timeout: float = 30.0,
max_batch_size: Optional[int] = None,
rerank_field: Optional[str] = None,
fusion_score_weight: float = 1.0,
concurrency: int = 4,
):
        if num_classes <= 0:
            raise ValueError(f"num_classes must be positive, got {num_classes}")
        super().__init__(
            query=query,
            topn=topn,
            rerank_field=rerank_field,
            fusion_score_weight=fusion_score_weight,
        )
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._model = model
        self._num_classes = num_classes
        self._timeout = timeout
        self._max_batch_size = max_batch_size
        self._concurrency = concurrency

_public_names = (
"base_url",
"api_key",
"model",
"num_classes",
"timeout",
"max_batch_size",
"concurrency",
)

    def _build_prompt(self, query: str, document: str) -> list[dict[str, str]]:
        """Build the prompt for classification.

        Args:
            query (str): Search query.
            document (str): Document text.

        Returns:
            list[dict[str, str]]: Messages for the chat completion API.
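
        Example (illustrative, given a constructed reranker instance):
            >>> msgs = reranker._build_prompt("what is ml?", "An intro to ML.")
            >>> [m["role"] for m in msgs]
            ['system', 'user']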
"""
classes_str = ", ".join(str(i) for i in range(self._num_classes))
return [
{
"role": "system",
"content": (
"You are a relevance classifier. "
f"Rate the document's relevance to the query using "
f"exactly one digit from: {classes_str}."
),
},
{
"role": "user",
"content": (
f"Query: {query}\n\nDocument: {document}\n\n"
f"Relevance score (choose one: {classes_str}):"
),
},
]

    def _call_api(self, messages: list[dict[str, str]]) -> dict[str, Any]:
        """Call the OpenAI-compatible chat completion API.

        Args:
            messages (list[dict[str, str]]): Messages for the chat API.

        Returns:
            dict[str, Any]: API response data.

        Raises:
            TimeoutError: If the request times out.
            ConnectionError: If the server cannot be reached.
            RuntimeError: If the API call fails or the response cannot be parsed.
        """
headers = {
"Content-Type": "application/json",
}
if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"
# Build regex for structured output
if self._num_classes <= self.MAX_CLASSES:
regex = f"[0-{self._num_classes - 1}]"
else:
regex = r"\d"
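        # e.g. num_classes=5 -> "[0-4]". Beyond MAX_CLASSES a single character
        # cannot cover every class (class 10 needs two digits but max_tokens=1),
        # so fall back to matching any one digit.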
payload: dict[str, Any] = {
"model": self._model,
"messages": messages,
"temperature": 0.0,
"max_tokens": 1,
"logprobs": True,
"top_logprobs": self._num_classes,
}
        # Structured-output constraint. "extra_body" is an OpenAI SDK concept
        # (the SDK merges it into the request body); since we post raw JSON
        # with httpx, the parameter goes at the top level of the payload.
        payload["structured_outputs"] = {"regex": regex}
try:
response = httpx.post(
f"{self._base_url}/chat/completions",
headers=headers,
json=payload,
timeout=self._timeout,
)
response.raise_for_status()
return response.json()
except httpx.TimeoutException as e:
raise TimeoutError(
f"Request to /chat/completions timed out after {self._timeout}s. "
"Consider increasing the timeout parameter."
) from e
except httpx.ConnectError as e:
raise ConnectionError(
f"Failed to connect to {self._base_url}/chat/completions. "
"Ensure the server is running and accessible."
) from e
except httpx.HTTPStatusError as e:
raise RuntimeError(
f"API returned status {e.response.status_code}: {e}"
) from e
        except json.JSONDecodeError as e:
            raise RuntimeError(f"Failed to parse API response: {e}") from e
except httpx.HTTPError as e:
raise RuntimeError(f"API call failed: {e}") from e

    def _compute_score(self, query: str, document: str) -> float:
        """Compute the expected-value score from logprobs.

        The expected value is: E = Σ(prob_i × i) / (num_classes - 1)
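
        For example, with num_classes=2 and renormalized probabilities
        {0: 0.2, 1: 0.8}: E = (0.2 × 0 + 0.8 × 1) / 1 = 0.8.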

        Args:
            query (str): Search query.
            document (str): Document text.

        Returns:
            float: Expected-value score normalized to [0, 1].
"""
messages = self._build_prompt(query, document)
response = self._call_api(messages)
# Extract log probabilities
choice = response["choices"][0]
logprobs_data = choice.get("logprobs", {})
content = logprobs_data.get("content", [])
if not content:
logger.warning("No logprobs returned, returning middle value")
return 0.5 if self._num_classes > 1 else 0.0
        # Standard chat-completions format: content[0] is a dict carrying a
        # "top_logprobs" list; some servers return the list directly instead.
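        # Illustrative entry (shape only; values depend on the model):
        #   {"token": "1", "logprob": -0.22,
        #    "top_logprobs": [{"token": "1", "logprob": -0.22},
        #                     {"token": "0", "logprob": -1.61}]}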
first_token_entry = content[0]
if isinstance(first_token_entry, dict) and "top_logprobs" in first_token_entry:
top_logprobs = first_token_entry.get("top_logprobs", [])
else:
top_logprobs = (
first_token_entry if isinstance(first_token_entry, list) else []
)
if not top_logprobs:
return 0.5 if self._num_classes > 1 else 0.0
# Map digit tokens to their logprobs
class_logprobs: dict[int, float] = {}
for token_data in top_logprobs:
token = token_data.get("token", "").strip()
logprob = token_data.get("logprob", 0.0)
# Check if token is a valid digit for our classes
if token.isdigit():
digit = int(token)
if 0 <= digit < self._num_classes:
class_logprobs[digit] = logprob
# If no classes matched, return middle value
if not class_logprobs:
return 0.5 if self._num_classes > 1 else 0.0
        # Exponentiate logprobs back to probabilities and renormalize over
        # the classes actually present in top_logprobs
exp_logprobs = {k: math.exp(v) for k, v in class_logprobs.items()}
total = sum(exp_logprobs.values())
if total == 0:
return 0.5 if self._num_classes > 1 else 0.0
probs = {k: v / total for k, v in exp_logprobs.items()}
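        # e.g. class_logprobs={0: -1.61, 1: -0.22} -> probs ≈ {0: 0.2, 1: 0.8}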
# Compute expected value
expected_value = sum(prob * digit for digit, prob in probs.items())
# Normalize to [0, 1] by dividing by (num_classes - 1)
if self._num_classes > 1:
expected_value /= self._num_classes - 1
return expected_value

    def _compute_scores_batch(
        self,
        query: str,
        documents: list[str],
    ) -> list[float]:
        """Compute relevance scores for a batch of documents.

        Args:
            query (str): The search query.
            documents (list[str]): List of document texts to score.

        Returns:
            list[float]: Expected-value scores, aligned with ``documents``.
        """
if not documents:
return []
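        # Sequential fallback: single document, or no batch size configured
        # (the thread pool below is only used when max_batch_size is set).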
if len(documents) <= 1 or self._max_batch_size is None:
return [self._compute_score(query, doc) for doc in documents]
        scores: list[float] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._concurrency
        ) as executor:
            futures = [
                executor.submit(self._compute_score, query, doc) for doc in documents
            ]
            # Iterate in submission order so scores stay aligned with documents;
            # as_completed() would yield results in completion order instead.
            for future in futures:
                try:
                    scores.append(future.result())
                except Exception as e:
                    logger.warning("Error computing score: %s", e)
                    scores.append(0.5)
        return scores