"""PipelineReranker for zvec-db.
This module contains the PipelineReranker class that allows chaining
multiple rerankers sequentially.
Example Usage
-------------
.. code-block:: python

    from zvec import create_and_open, VectorQuery
    from zvec_db.rerankers import (
        RrfReranker,
        OpenAICrossEncoderReranker,
        PipelineReranker,
    )

    collection = create_and_open(...)

    # Pipeline: RRF then Cross-Encoder
    pipeline = PipelineReranker(
        rerankers=[
            RrfReranker(topn=50, rank_constant=60),
            OpenAICrossEncoderReranker(
                topn=10,
                model="gpt-4o-mini",
                query="machine learning"
            )
        ]
    )

    results = collection.query(
        vectors=[
            VectorQuery(field_name="bm25", vector=bm25_vec),
            VectorQuery(field_name="dense", vector=dense_vec),
        ],
        topk=10,
        reranker=pipeline
    )
"""
from __future__ import annotations

import inspect
from typing import Any, Optional

from zvec import ReRanker
from zvec.model.doc import Doc
class PipelineReranker(ReRanker):
    """Chain multiple rerankers sequentially.

    This reranker applies a list of rerankers in sequence, passing the output
    of one as the input to the next. This is useful for combining different
    reranking strategies (e.g., RRF followed by cross-encoder).

    Args:
        rerankers (list): Rerankers to apply in order.
        topn (int, optional): Forwarded to the base ``ReRanker``. The pipeline
            itself does not truncate results; the number of documents returned
            is whatever the last reranker in the chain produces.
            Defaults to 10.
        rerank_field (Optional[str], optional): Forwarded to the base
            ``ReRanker``; not used by the pipeline itself. Defaults to None.

    Example:
        >>> pipeline = PipelineReranker([
        ...     RrfReranker(topn=50, rank_constant=60),
        ...     SentenceTransformerReranker(model_name="ms-marco-MiniLM-L-6-v2", topn=10)
        ... ])
        >>> results = collection.query(..., reranker=pipeline)
    """

    def __init__(
        self,
        rerankers: list,
        topn: int = 10,
        rerank_field: Optional[str] = None,
    ):
        """Initialize PipelineReranker with a list of rerankers.

        Args:
            rerankers (list): Reranker instances to apply in order.
                Each reranker must implement the ``rerank()`` method.
            topn (int, optional): Forwarded to the base ``ReRanker``; the
                pipeline does not truncate the final list itself.
                Defaults to 10.
            rerank_field (Optional[str], optional): Forwarded to the base
                ``ReRanker``; not used by the pipeline. Defaults to None.

        Example:
            >>> pipeline = PipelineReranker([
            ...     RrfReranker(topn=50, rank_constant=60),
            ...     SentenceTransformerReranker(model_name="ms-marco-MiniLM-L-6-v2", topn=10)
            ... ])
            >>> results = collection.query(..., reranker=pipeline)
        """
        super().__init__(topn=topn, rerank_field=rerank_field)
        # Copy so that later mutation of the caller's list cannot silently
        # change the stages of an already-built pipeline.
        self._rerankers = list(rerankers)

    @staticmethod
    def _accepts_query(reranker: Any) -> bool:
        """Return True if ``reranker.rerank`` can take a ``query`` keyword."""
        try:
            parameters = inspect.signature(reranker.rerank).parameters
        except (TypeError, ValueError):
            # Signature not introspectable (e.g. some C-implemented callables):
            # keep the historical behaviour of always passing ``query``.
            return True
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        for parameter in parameters.values():
            if parameter.kind is inspect.Parameter.VAR_KEYWORD:
                return True
            if parameter.name == "query" and parameter.kind in keyword_kinds:
                return True
        return False

    def rerank(
        self,
        query_results: dict[str, list[Doc]],
        query: Optional[str] = None,
    ) -> list[Doc]:
        """Apply rerankers sequentially.

        Each stage receives the previous stage's output. A list produced by a
        non-final stage is wrapped as ``{"pipeline": docs}`` so the next stage
        again receives a mapping of result lists.

        Args:
            query_results (dict[str, list[Doc]]): Results from vector queries.
            query (Optional[str], optional): The search query. Passed as the
                ``query`` keyword to every underlying reranker whose
                ``rerank()`` accepts it; rerankers that do not accept it are
                called with the results only. Defaults to None.

        Returns:
            list[Doc]: Final re-ranked documents after all rerankers applied.
            If the last stage (or an empty pipeline) leaves a dict, the
            ``"pipeline"`` entry is returned, else the ``"source"`` entry,
            else the first value; an empty dict yields an empty list.
        """
        current_results: Any = query_results
        last_index = len(self._rerankers) - 1

        for index, reranker in enumerate(self._rerankers):
            # Only forward the query to rerankers that can actually take it.
            if self._accepts_query(reranker):
                current_results = reranker.rerank(current_results, query=query)
            else:
                current_results = reranker.rerank(current_results)

            # Re-wrap intermediate lists so the next stage sees a mapping.
            if index < last_index and isinstance(current_results, list):
                current_results = {"pipeline": current_results}

        if not isinstance(current_results, dict):
            return current_results

        # Extract the final list from a dict-shaped result.
        for key in ("pipeline", "source"):
            if key in current_results:
                return current_results[key]
        # Default of [] avoids StopIteration when the dict is empty.
        return next(iter(current_results.values()), [])