# Source code for zvec_db.rerankers.cross_encoder.base

"""Base class for cross-encoder reranking."""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import List, Optional

from zvec.model.doc import Doc

from ..base import RerankFunction
from ..utils.base_utils import get_document_text

logger = logging.getLogger(__name__)


def _make_property(name: str, doc: str = ""):
    """Factory to create a property for a private attribute.

    Args:
        name: Attribute name without underscore prefix (e.g., "query" for "_query")
        doc: Docstring for the property

    Returns:
        property: A property that returns self._{name}
    """

    def getter(self):
        return getattr(self, f"_{name}")

    return property(getter, doc=doc)


class CrossEncoderPropertyMixin:
    """Mixin that auto-generates read-only properties for private attributes.

    Subclasses declare a ``_public_names`` tuple of attribute names (without
    the underscore prefix); each listed name is exposed as a property backed
    by ``self._{name}``, unless the subclass already defines that attribute.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Generate a property for every declared public name that the
        # subclass has not explicitly provided itself.
        for attr_name in getattr(cls, "_public_names", ()):
            if hasattr(cls, attr_name):
                # Respect an explicitly-defined attribute/property.
                continue
            setattr(cls, attr_name, _make_property(attr_name))
class BaseCrossEncoderReranker(CrossEncoderPropertyMixin, RerankFunction, ABC):
    """Abstract base class for cross-encoder reranking.

    This class provides the common infrastructure for cross-encoder scoring.
    Subclasses must implement the `_compute_scores_batch()` method to define
    their scoring strategy.

    Args:
        query (str): Query for reranking. **Required**.
        topn (int, optional): Number of top documents to return after
            reranking. Defaults to 10.
        rerank_field (Optional[str], optional): Document field to use for
            reranking. If None, uses the entire document content.
            Defaults to None.
        fusion_score_weight (float, optional): Weight for blending
            cross-encoder scores with fusion scores. Formula:
            final_score = cross_encoder_score × weight + fusion_score × (1 - weight)

            - weight = 1.0 → 100% cross-encoder, 0% fusion (pure cross-encoder, default)
            - weight = 0.8 → 80% cross-encoder, 20% fusion
            - weight = 0.5 → 50% cross-encoder, 50% fusion
            - weight = 0.0 → 0% cross-encoder, 100% fusion (pure fusion)

            Defaults to 1.0 (pure cross-encoder score).

    Note:
        - Subclasses must implement `_compute_scores_batch()`
        - Cross-encoder reranking is more accurate but slower than score fusion
        - For large document sets, consider using max_batch_size to limit API calls
    """

    def __init__(
        self,
        query: str,
        topn: int = 10,
        rerank_field: Optional[str] = None,
        fusion_score_weight: float = 1.0,
    ):
        super().__init__(topn=topn, rerank_field=rerank_field)
        self._query = query
        self._fusion_score_weight = fusion_score_weight

    @property
    def query(self) -> str:
        """str: Default query for reranking."""
        return self._query

    @property
    def fusion_score_weight(self) -> float:
        """float: Weight for blending cross-encoder scores with fusion scores."""
        return self._fusion_score_weight

    @abstractmethod
    def _compute_scores_batch(
        self,
        query: str,
        documents: List[str],
    ) -> List[float]:
        """Compute relevance scores for a batch of documents.

        Args:
            query (str): The search query.
            documents (List[str]): List of document texts to score.

        Returns:
            List[float]: List of relevance scores.
        """
        ...

    def _get_document_text(self, doc: Doc) -> str:
        """Extract document text for scoring.

        Delegates to :func:`base_utils.get_document_text` for the actual
        logic. Subclasses may override this to customize text extraction;
        `rerank()` routes all extraction through this hook.

        Args:
            doc (Doc): Document to extract text from.

        Returns:
            str: Document text content.
        """
        return get_document_text(doc, self._rerank_field)

    def rerank(
        self, query_results: dict[str, list[Doc]], query: Optional[str] = None
    ) -> list[Doc]:
        """Rerank documents using cross-encoder scoring.

        Args:
            query_results (dict[str, list[Doc]]): Results from one or more
                vector queries.
            query (Optional[str], optional): Query for reranking. Overrides
                constructor value if provided.

        Returns:
            list[Doc]: Reranked documents with cross-encoder scores.

        Raises:
            ValueError: If no query was provided via constructor or argument
                (only raised when there are documents to score).
        """
        # Use provided query, or fall back to constructor value
        effective_query = query if query is not None else self._query

        # Collect all documents, deduplicating by ID across result lists
        all_docs: list[Doc] = []
        seen_ids: set[str] = set()
        for docs in query_results.values():
            for doc in docs:
                if doc.id not in seen_ids:
                    all_docs.append(doc)
                    seen_ids.add(doc.id)

        if not all_docs:
            return []

        if effective_query is None:
            raise ValueError(
                "CrossEncoderReranker requires a query. "
                "Provide it via constructor or rerank(query=...) parameter."
            )

        # Extract document texts via the overridable hook so subclasses that
        # customize _get_document_text() are honored here as well.
        # (Previously this called get_document_text() directly, bypassing
        # the hook; a redundant doc_mapping copy of all_docs is also removed.)
        doc_texts = [self._get_document_text(doc) for doc in all_docs]

        # Compute cross-encoder scores for all documents in one batch
        scores = self._compute_scores_batch(effective_query, doc_texts)

        # Blend cross-encoder scores with fusion scores
        # fusion_score_weight = 1.0 → 100% cross-encoder, 0% fusion
        # fusion_score_weight = 0.0 → 0% cross-encoder, 100% fusion
        scored_docs: list[tuple[float, Doc]] = []
        for score, doc in zip(scores, all_docs):
            fusion_score = doc.score if doc.score is not None else 0.0
            blended = (
                self._fusion_score_weight * score
                + (1 - self._fusion_score_weight) * fusion_score
            )
            scored_docs.append((blended, doc))

        # Sort by blended score, descending
        scored_docs.sort(key=lambda x: x[0], reverse=True)

        # Return top-n documents with their scores replaced by the blended score
        return [doc._replace(score=score) for score, doc in scored_docs[: self.topn]]