"""Weighted fusion rerankers for combining scores from multiple sources.

This module provides rerankers that combine scores from multiple retrieval
sources using weighted-sum fusion.

Classes
-------
WeightedReranker
    Weighted fusion with optional normalization and metric conversion.

Example Usage
-------------
::

    from zvec.typing import MetricType
    from zvec_db.rerankers import WeightedReranker

    # Scores already normalized to [0, 1], higher is better
    reranker = WeightedReranker(
        metrics=MetricType.IP,  # already similarities; no conversion needed
        weights={"bm25": 0.7, "dense": 0.3},
    )
    results = reranker.rerank({"bm25": bm25_docs, "dense": dense_docs})

    # Raw scores with automatic normalization (smart default)
    reranker = WeightedReranker(
        metrics={"bm25": MetricType.IP, "dense": MetricType.COSINE},
        weights={"bm25": 0.7, "dense": 0.3},
        normalize=True,  # COSINE -> no additional norm, others -> bayes
    )
    results = reranker.rerank({"bm25": bm25_docs, "dense": dense_docs})

    # Per-source normalization config
    reranker = WeightedReranker(
        metrics={"bm25": MetricType.IP, "dense": MetricType.COSINE},
        weights={"bm25": 0.7, "dense": 0.3},
        normalize={"bm25": "bayes", "dense": None},  # COSINE never normalized
    )
    results = reranker.rerank({"bm25": bm25_docs, "dense": dense_docs})
"""

from __future__ import annotations

import heapq
import logging
from typing import TYPE_CHECKING, Any, Optional, Union
from zvec.model.doc import Doc
from zvec.typing import MetricType
from ..base import FusionRerankerBase
from ..utils.base_utils import _UNSET, extract_score
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from zvec.model.schema import CollectionSchema
class WeightedReranker(FusionRerankerBase):
r"""Weighted fusion with optional normalization and metric conversion.
This class combines scores from multiple sources using weighted sum:
.. math::
\text{score}(d) = \sum_{s \in S} \text{norm}(\text{score}_s(d)) \times w_s
where :math:`w_s` is the weight for source :math:`s`.
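
For example, with weights ``{"bm25": 0.7, "dense": 0.3}`` and normalized
scores of 0.8 (bm25) and 0.6 (dense) for the same document, the fused
score is ``0.8 * 0.7 + 0.6 * 0.3 = 0.74``.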
Features:
- Optional distance->similarity conversion (COSINE, L2, IP)
- Optional normalization per source (bayes, minmax, percentile)
- Smart defaults: COSINE -> no additional normalization, others -> bayes
Distance-to-similarity conversion:

- **COSINE**: ``(2 - score) / 2`` - distance [0, 2] -> similarity [0, 1]
- **L2**: ``-score`` - inverts order
- **IP**: no conversion (already a similarity, including BM25 scores)

Note:
    The COSINE metric is NEVER additionally normalized - the conversion
    formula ``(2 - score) / 2`` already produces scores in [0, 1]. Setting
    ``normalize`` for COSINE sources has no effect.
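
A quick sanity check of the conversion arithmetic (plain Python,
illustrative only):

>>> (2 - 0.4) / 2  # COSINE distance 0.4 -> similarity
0.8
>>> -2.5 > -3.0    # smaller L2 distance -> larger converted score
True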
Normalization methods (applied AFTER conversion, except for COSINE):

- **bayes** (default for non-COSINE): Bayesian sigmoid calibration
- **minmax**: ``(x - min) / (max - min)``
- **percentile**: rank-based normalization
- **default**: index-aware scaling with avgscore
- **atan**: arctan-based normalization ``0.5 + atan(s)/pi``
  (assumes scores have already been converted to "higher is better")
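
A quick illustration of min-max scaling in plain Python (not the
library's internal implementation):

>>> scores = [2.0, 5.0, 8.0]
>>> lo, hi = min(scores), max(scores)
>>> [(s - lo) / (hi - lo) for s in scores]
[0.0, 0.5, 1.0]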
Args:
topn (int, optional): Number of top documents to return. Defaults to 10.
rerank_field (Optional[str], optional): Ignored. Defaults to None.
weights (Optional[dict[str, float]], optional): Weight per source.
Sources not listed use weight 1.0. Defaults to None (equal weights).
normalize (Union[bool, str, dict[str, Any], None], optional): Normalization
configuration. Can be:
- ``True`` (default): Smart default - COSINE -> no norm, others -> "bayes"
- ``str``: Method name ("bayes", "minmax", "percentile", "default", "atan")
- ``dict``: Per-source config, e.g., {"sparse": "bayes", "dense": None}
- ``None`` or ``False``: No normalization (raw scores after conversion)
metrics (Optional[Union[MetricType, dict[str, MetricType]]], optional):
Metric type(s) for converting distances to similarities. Can be:
- A single MetricType (e.g., ``MetricType.COSINE``) applied to all sources
- A dict mapping source names to their metric type
(use ``MetricType.IP`` for sources that don't need conversion, e.g., BM25 scores)
- If None and schema is provided, metrics are inferred from the schema
schema (Optional[CollectionSchema], optional): Collection schema to
automatically extract metrics from. If provided and metrics is None,
metrics are inferred from the schema (defaults to IP).
Raises:
ValueError: If neither metrics nor schema is provided.
Example:
>>> # Scores already normalized to [0, 1]; IP means no conversion
>>> reranker = WeightedReranker(
...     metrics=MetricType.IP,
...     weights={"bm25": 0.7, "dense": 0.3},
... )
>>> results = reranker.rerank({
...     "bm25": bm25_docs_normalized,
...     "dense": dense_docs_normalized,
... })
>>> # Raw scores with smart default normalization
>>> reranker = WeightedReranker(
...     metrics={"bm25": MetricType.IP, "dense": MetricType.COSINE},
...     weights={"bm25": 0.7, "dense": 0.3},
...     normalize=True,  # COSINE -> no additional norm, others -> bayes
... )
>>> results = reranker.rerank({"bm25": bm25_docs, "dense": dense_docs})
>>> # Per-source normalization config
>>> reranker = WeightedReranker(
...     metrics={"bm25": MetricType.IP, "dense": MetricType.COSINE},
...     weights={"bm25": 0.7, "dense": 0.3},
...     normalize={"bm25": "bayes", "dense": "cosine"},  # "cosine" = no-op
... )
>>> # No normalization (raw scores after conversion only)
>>> reranker = WeightedReranker(
... metrics={"bm25": MetricType.IP},
... normalize=None
... )
>>> # Schema auto-detection (recommended with zvec)
>>> import zvec
>>> collection = zvec.open("./my_collection")
>>> reranker = WeightedReranker(
... schema=collection.schema,
... weights={"dense": 0.7, "bm25": 0.3},
... normalize=True
... )
Note:
    Distance-to-similarity conversion is applied before normalization:

    - **COSINE**: ``(2 - score) / 2`` (distance [0, 2] -> similarity [0, 1])
    - **L2**: ``-score`` (inverts order)
    - **IP**: no conversion (already a similarity, including BM25 scores)
See Also:
RrfReranker: Rank-based fusion (uses ranks, not scores).
"""
def __init__(
self,
topn: int = 10,
rerank_field: Optional[str] = None,
weights: Optional[dict[str, float]] = None,
normalize: Union[bool, str, dict[str, Any], None] = True,
metrics: Optional[
Union[MetricType, dict[str, Union[str, MetricType, None]]]
] = _UNSET, # type: ignore[assignment]
schema: Optional[CollectionSchema] = None,
):
"""Initialize WeightedReranker.
Args:
topn: Number of top documents to return.
rerank_field: Ignored.
weights: Weight per source. Defaults to equal weights.
normalize: Normalization configuration. Can be:
- ``True`` (default): Smart default - COSINE -> no-op, others -> "bayes"
- ``"bayes"``: Bayesian sigmoid calibration for all sources
- ``"minmax"``: (x - min) / (max - min) for all sources
- ``"percentile"``: Rank-based normalization for all sources
- ``"cosine"``: No-op (identity). COSINE scores already in [0, 1]
- ``"default"``: Min-max with avgscore
- ``dict``: Per-source config, e.g., {"sparse": "bayes", "dense": "cosine"}
- ``None`` or ``False``: No normalization (raw scores after conversion)
metrics: Metric type(s) for distance-to-similarity conversion.
Can be a single MetricType for all sources, or a dict for per-source metrics.
If None and schema is provided, metrics are inferred from the schema.
schema (Optional[CollectionSchema]): Collection schema to automatically
    extract metrics from.
Raises:
ValueError: If neither metrics nor schema is provided.
"""
# Validation: metrics or schema required
if metrics is _UNSET and schema is None:
raise ValueError(
"Either 'metrics' or 'schema' must be provided. "
"For no conversion: metrics=MetricType.IP. "
"For hybrid search: metrics={'sparse': MetricType.IP, 'dense': MetricType.COSINE}"
)
# Convert sentinel to None for internal use
if metrics is _UNSET:
metrics = None
super().__init__(
topn=topn, rerank_field=rerank_field, schema=schema, metrics=metrics
)
self._weights = weights or {}
self._normalize = normalize
self._validate_weights()
def _validate_weights(self):
"""Validate that all weights are non-negative.
Raises:
ValueError: If any weight is negative or not a valid number.
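
Example (assumes ``metrics=MetricType.IP`` constructs cleanly, as the
class docstring suggests):

>>> WeightedReranker(metrics=MetricType.IP, weights={"bm25": -0.5})
Traceback (most recent call last):
    ...
ValueError: Weight for source 'bm25' cannot be negative: -0.5. All weights must be >= 0.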
"""
for source, weight in self._weights.items():
if not isinstance(weight, (int, float)):
raise ValueError(
f"Weight for source '{source}' must be a number, got {type(weight).__name__}"
)
if weight < 0:
raise ValueError(
f"Weight for source '{source}' cannot be negative: {weight}. "
"All weights must be >= 0."
)
# Warn if weight is 0 (may indicate configuration error)
if weight == 0:
logger.warning(
"Weight for source '%s' is 0 - this source will be ignored in fusion.",
source,
)
@property
def weights(self) -> dict[str, float]:
    """Per-source fusion weights (sources not listed default to 1.0)."""
    return self._weights
@property
def normalize(self) -> Union[bool, str, dict[str, Any], None]:
    """Normalization configuration as passed to ``__init__``."""
    return self._normalize
@staticmethod
def _extract_score(doc: Doc) -> float:
"""Extract score from a document.
Args:
doc (Doc): Document with a score attribute.
Returns:
float: Score as a float, or 0.0 if score is None.
"""
return extract_score(doc)
def rerank(
self, query_results: dict[str, list[Doc]], query: Optional[str] = None
) -> list[Doc]:
"""Convert scores and compute weighted fusion.
Steps:
1. Convert metrics to ensure higher=better:
- COSINE: (2 - score) / 2
- L2: -score (inverts order)
- IP: no conversion
2. Apply normalization per source (COSINE: skipped, others: bayes by default)
3. Filter out documents with normalized score <= 0
4. Compute weighted fusion
Args:
query_results (dict[str, list[Doc]]): Dictionary mapping source
names to lists of documents.
query (Optional[str], optional): Ignored. Defaults to None.
Returns:
list[Doc]: Reranked documents with weighted scores.
Note:
COSINE scores are NOT additionally normalized after conversion,
since (2-score)/2 already produces scores in [0, 1].
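
A hand-worked example of step 4 (illustrative numbers, plain Python):

>>> w = {"bm25": 0.7, "dense": 0.3}     # configured weights
>>> norm = {"bm25": 0.9, "dense": 0.5}  # normalized scores for one doc
>>> round(sum(norm[s] * w[s] for s in norm), 4)
0.78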
"""
del query # WeightedReranker does not use the query
weighted_scores: dict[str, float] = {}
id_to_doc: dict[str, Doc] = {}
for key, docs in query_results.items():
if not docs:
continue
# Process scores: extract, convert, normalize
score_map, doc_map = self._process_source_scores(key, docs)
# Compute weighted fusion
for doc in docs:
doc_id = doc.id
norm_score = score_map.get(doc_id, 0.0)
# Filter out non-relevant docs (score <= 0)
if norm_score > 0:
weight = self._weights.get(key, 1.0)
weighted_scores[doc_id] = (
weighted_scores.get(doc_id, 0.0) + norm_score * weight
)
if doc_id not in id_to_doc:
id_to_doc[doc_id] = doc
else:
logger.debug(
"Filtered doc '%s' from source '%s' (score <= 0)", doc_id, key
)
# Get top documents
top_docs = heapq.nlargest(
self.topn, weighted_scores.items(), key=lambda x: x[1]
)
results: list[Doc] = []
for doc_id, weighted_score in top_docs:
doc = id_to_doc[doc_id]
results.append(doc._replace(score=weighted_score))
logger.debug(
"WeightedReranker: %d docs -> %d results",
sum(len(d) for d in query_results.values()),
len(results),
)
return results
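
if __name__ == "__main__":
    # Minimal smoke test of the weighted-sum loop in rerank(), written with
    # plain dicts instead of Doc objects. Illustrative only: it skips the
    # metric conversion and normalization steps that rerank() performs.
    demo_weights = {"bm25": 0.7, "dense": 0.3}
    demo_scores = {"bm25": {"a": 0.9, "b": 0.4}, "dense": {"a": 0.5}}
    fused: dict[str, float] = {}
    for source, score_map in demo_scores.items():
        for doc_id, score in score_map.items():
            if score > 0:  # same relevance filter as rerank()
                fused[doc_id] = (
                    fused.get(doc_id, 0.0)
                    + score * demo_weights.get(source, 1.0)
                )
    # Expected (up to float rounding): [('a', 0.78), ('b', 0.28)]
    print(heapq.nlargest(2, fused.items(), key=lambda x: x[1]))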