"""Base class for cross-encoder reranking."""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
from zvec.model.doc import Doc
from ..base import RerankFunction
from ..utils.base_utils import get_document_text
logger = logging.getLogger(__name__)
def _make_property(name: str, doc: str = ""):
"""Factory to create a property for a private attribute.
Args:
name: Attribute name without underscore prefix (e.g., "query" for "_query")
doc: Docstring for the property
Returns:
property: A property that returns self._{name}
"""
def getter(self):
return getattr(self, f"_{name}")
return property(getter, doc=doc)
class CrossEncoderPropertyMixin:
    """Mixin that exposes private ``_name`` attributes as public properties.

    Subclasses should define a `_public_names` tuple with attribute names
    (without underscore prefix) to expose as read-only properties.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Attach a read-only property proxying the underscore-prefixed
        # attribute for each listed name, unless the subclass already
        # defines something under that name.
        for attr in getattr(cls, "_public_names", ()):
            if hasattr(cls, attr):
                continue

            # Bind the backing name as a default argument so each closure
            # captures its own value (avoids the late-binding pitfall).
            def _reader(self, _private=f"_{attr}"):
                return getattr(self, _private)

            setattr(cls, attr, property(_reader, doc=""))
class BaseCrossEncoderReranker(CrossEncoderPropertyMixin, RerankFunction, ABC):
    """Abstract base class for cross-encoder reranking.

    This class provides the common infrastructure for cross-encoder scoring.
    Subclasses must implement the `_compute_scores_batch()` method to define
    their scoring strategy.

    Args:
        query (str): Query for reranking. **Required**.
        topn (int, optional): Number of top documents to return after reranking.
            Defaults to 10.
        rerank_field (Optional[str], optional): Document field to use for reranking.
            If None, uses the entire document content. Defaults to None.
        fusion_score_weight (float, optional): Weight for blending cross-encoder scores
            with fusion scores.

            Formula: final_score = cross_encoder_score × weight + fusion_score × (1 - weight)

            - weight = 1.0 → 100% cross-encoder, 0% fusion (pure cross-encoder, default)
            - weight = 0.8 → 80% cross-encoder, 20% fusion
            - weight = 0.5 → 50% cross-encoder, 50% fusion
            - weight = 0.0 → 0% cross-encoder, 100% fusion (pure fusion)

            Defaults to 1.0 (pure cross-encoder score).

    Note:
        - Subclasses must implement `_compute_scores_batch()` or `_compute_score()`
        - Cross-encoder reranking is more accurate but slower than score fusion
        - For large document sets, consider using max_batch_size to limit API calls
    """

    def __init__(
        self,
        query: str,
        topn: int = 10,
        rerank_field: Optional[str] = None,
        fusion_score_weight: float = 1.0,
    ):
        super().__init__(topn=topn, rerank_field=rerank_field)
        self._query = query
        self._fusion_score_weight = fusion_score_weight

    @property
    def query(self) -> str:
        """str: Default query for reranking."""
        return self._query

    @property
    def fusion_score_weight(self) -> float:
        """float: Weight for blending cross-encoder scores with fusion scores."""
        return self._fusion_score_weight

    @abstractmethod
    def _compute_scores_batch(
        self,
        query: str,
        documents: List[str],
    ) -> List[float]:
        """Compute relevance scores for a batch of documents.

        Args:
            query (str): The search query.
            documents (List[str]): List of document texts to score.

        Returns:
            List[float]: List of relevance scores, one per input document,
            in the same order as ``documents``.
        """
        ...

    def _get_document_text(self, doc: Doc) -> str:
        """Extract document text for scoring.

        Delegates to :func:`base_utils.get_document_text` for the actual logic.

        Args:
            doc (Doc): Document to extract text from.

        Returns:
            str: Document text content.
        """
        return get_document_text(doc, self._rerank_field)

    def rerank(
        self, query_results: dict[str, list[Doc]], query: Optional[str] = None
    ) -> list[Doc]:
        """Rerank documents using cross-encoder scoring.

        Args:
            query_results (dict[str, list[Doc]]): Results from one or more
                vector queries.
            query (Optional[str], optional): Query for reranking. Overrides
                constructor value if provided.

        Returns:
            list[Doc]: Top-n reranked documents with blended scores.

        Raises:
            ValueError: If no query is available, or if the subclass returns
                a score list whose length does not match the document count.
        """
        # Use the per-call query when given; otherwise fall back to the
        # constructor value.
        effective_query = query if query is not None else self._query

        # Collect all documents across result lists, deduplicating by ID
        # (the same document may match more than one vector query).
        all_docs: list[Doc] = []
        seen_ids: set[str] = set()
        for docs in query_results.values():
            for doc in docs:
                if doc.id not in seen_ids:
                    all_docs.append(doc)
                    seen_ids.add(doc.id)
        if not all_docs:
            return []
        if effective_query is None:
            raise ValueError(
                "CrossEncoderReranker requires a query. "
                "Provide it via constructor or rerank(query=...) parameter."
            )

        # Extract document texts through the overridable hook so subclasses
        # that customize _get_document_text() are honored here as well.
        doc_texts = [self._get_document_text(doc) for doc in all_docs]

        # Compute cross-encoder scores for all documents in one batch.
        scores = self._compute_scores_batch(effective_query, doc_texts)
        if len(scores) != len(all_docs):
            # Guard against a buggy subclass: zip() below would otherwise
            # silently truncate to the shorter of the two lists.
            raise ValueError(
                f"_compute_scores_batch returned {len(scores)} scores "
                f"for {len(all_docs)} documents"
            )

        # Blend cross-encoder scores with fusion scores:
        #   fusion_score_weight = 1.0 → 100% cross-encoder, 0% fusion
        #   fusion_score_weight = 0.0 → 0% cross-encoder, 100% fusion
        scored_docs: list[tuple[float, Doc]] = []
        for score, doc in zip(scores, all_docs):
            fusion_score = doc.score if doc.score is not None else 0.0
            blended = (
                self._fusion_score_weight * score
                + (1 - self._fusion_score_weight) * fusion_score
            )
            scored_docs.append((blended, doc))

        # Sort by blended score descending and return top-n with the new
        # scores written back onto the documents.
        scored_docs.sort(key=lambda x: x[0], reverse=True)
        return [doc._replace(score=s) for s, doc in scored_docs[: self.topn]]