Source code for zvec_db.rerankers.utils.pipeline

"""PipelineReranker for zvec-db.

This module contains the PipelineReranker class that allows chaining
multiple rerankers sequentially.

Example Usage
-------------
.. code-block:: python

    from zvec import create_and_open, VectorQuery
    from zvec_db.rerankers import (
        RrfReranker,
        OpenAICrossEncoderReranker,
        PipelineReranker,
    )

    collection = create_and_open(...)

    # Pipeline: RRF then Cross-Encoder
    pipeline = PipelineReranker(
        rerankers=[
            RrfReranker(topn=50, rank_constant=60),
            OpenAICrossEncoderReranker(
                topn=10,
                model="gpt-4o-mini",
                query="machine learning"
            )
        ]
    )

    results = collection.query(
        vectors=[
            VectorQuery(field_name="bm25", vector=bm25_vec),
            VectorQuery(field_name="dense", vector=dense_vec),
        ],
        topk=10,
        reranker=pipeline
    )
"""

from __future__ import annotations

from typing import Any, Optional

from zvec import ReRanker
from zvec.model.doc import Doc


class PipelineReranker(ReRanker):
    """Chain multiple rerankers sequentially.

    Each reranker in the list is applied in order; the output of one stage
    becomes the input of the next. This is useful for combining different
    reranking strategies (e.g., RRF followed by a cross-encoder).

    Args:
        rerankers (list): Rerankers to apply, in order.
        topn (int, optional): Forwarded to the ``ReRanker`` base class.
            Defaults to 10.
        rerank_field (Optional[str], optional): Forwarded to the base class;
            not read by the pipeline itself. Defaults to None.

    Note:
        NOTE(review): ``rerank()`` does not truncate its result to ``topn``;
        the number of returned documents is whatever the last reranker
        produces. Confirm whether truncation is intended.

    Example:
        >>> pipeline = PipelineReranker([
        ...     RrfReranker(topn=50, rank_constant=60),
        ...     SentenceTransformerReranker(
        ...         model_name="ms-marco-MiniLM-L-6-v2", topn=10
        ...     ),
        ... ])
        >>> results = collection.query(..., reranker=pipeline)
    """

    def __init__(
        self,
        rerankers: list,
        topn: int = 10,
        rerank_field: Optional[str] = None,
    ):
        """Initialize PipelineReranker with a list of rerankers.

        Args:
            rerankers (list): Reranker instances to apply in order. Each
                must implement a ``rerank()`` method.
            topn (int, optional): Forwarded to the base class. Defaults to 10.
            rerank_field (Optional[str], optional): Forwarded to the base
                class; not read by the pipeline itself. Defaults to None.
        """
        super().__init__(topn=topn, rerank_field=rerank_field)
        # Store a copy so that later mutation of the caller's list cannot
        # silently change the stages of an already-built pipeline.
        self._rerankers = list(rerankers)

    @staticmethod
    def _accepts_query(reranker: Any) -> bool:
        """Return True if ``reranker.rerank`` can take a ``query`` keyword."""
        import inspect

        try:
            parameters = inspect.signature(reranker.rerank).parameters
        except (TypeError, ValueError):
            # Signature is not introspectable; keep the historical behavior
            # of always passing the query.
            return True

        query_param = parameters.get("query")
        if query_param is not None:
            return query_param.kind is not inspect.Parameter.POSITIONAL_ONLY
        # No explicit ``query`` parameter: only safe if **kwargs absorbs it.
        return any(
            param.kind is inspect.Parameter.VAR_KEYWORD
            for param in parameters.values()
        )

    def rerank(
        self,
        query_results: dict[str, list[Doc]],
        query: Optional[str] = None,
    ) -> list[Doc]:
        """Apply rerankers sequentially.

        Args:
            query_results (dict[str, list[Doc]]): Results from vector queries.
            query (Optional[str], optional): The search query. Passed to each
                underlying reranker whose ``rerank()`` accepts a ``query``
                keyword; rerankers without one are called with the results
                only. Defaults to None.

        Returns:
            list[Doc]: Final re-ranked documents after all rerankers have
            been applied. If the final value is still a dict, the list under
            ``"pipeline"`` is returned, else the one under ``"source"``, else
            the first value; an empty dict yields an empty list.
        """
        current_results: Any = query_results
        last_index = len(self._rerankers) - 1

        for index, reranker in enumerate(self._rerankers):
            # Only pass the query to rerankers that can actually receive it;
            # otherwise a plain ``rerank(query_results)`` stage would raise
            # TypeError on the unexpected keyword.
            if self._accepts_query(reranker):
                current_results = reranker.rerank(current_results, query=query)
            else:
                current_results = reranker.rerank(current_results)

            # A list output must be re-wrapped as a dict for the next stage;
            # the last stage's output is left untouched.
            if index < last_index and isinstance(current_results, list):
                current_results = {"pipeline": current_results}

        if not isinstance(current_results, dict):
            return current_results

        # Extract the final list from a dict result, preferring known keys.
        for key in ("pipeline", "source"):
            if key in current_results:
                return current_results[key]
        # Default of [] avoids a bare StopIteration on an empty dict.
        return next(iter(current_results.values()), [])