Source code for zvec_db.rerankers.fusion.multi_field

"""Multi-field weighted reranker for combining document fields with different weights.

This module provides a reranker that can combine scores from multiple sources
and multiple fields within each document, allowing fine-grained control over
how different parts of a document contribute to the final score.

Example Usage
-------------
::

    from zvec.typing import MetricType
    from zvec_db.reranker import MultiFieldWeightedReranker

    # Rerank with field-level weights: title > content > tags
    reranker = MultiFieldWeightedReranker(
        topn=10,
        weights={"bm25": 0.6, "dense": 0.4},
        field_weights={"title": 3.0, "content": 1.0, "tags": 0.5},
        metrics={"bm25": MetricType.IP, "dense": MetricType.COSINE}
    )

    results = reranker.rerank({
        "bm25": bm25_results,
        "dense": dense_results
    })
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Union

from zvec.model.doc import Doc
from zvec.typing import MetricType

from ..utils.base_utils import extract_field_score
from .weighted import WeightedReranker

if TYPE_CHECKING:
    from zvec.model.schema import CollectionSchema

# Sentinel value to distinguish "not provided" from "explicitly None"
_UNSET = object()


class MultiFieldWeightedReranker(WeightedReranker):
    r"""Reranker that combines scores from multiple sources and document fields.

    This reranker extends the standard weighted fusion approach by supporting
    field-level weighting within documents. This is useful when documents have
    structured fields (e.g., title, content, tags) and you want to weight their
    contributions differently.

    The score fusion is computed as:

    .. math::

        \text{score}(d) = \sum_{s \in S} w_s \times
        \text{norm}_s\left( \frac{\sum_{f \in F} w_f \, \text{score}_{s,f}(d)}{\sum_{f \in F} w_f} \right)

    where:

    - :math:`w_s` is the weight for source :math:`s`
    - :math:`w_f` is the weight for field :math:`f`
    - :math:`\text{norm}_s` is the normalization function for source :math:`s` (standard or Bayesian)

    Only fields with a positive score contribute to the inner sums; if no weighted
    field matches, the document's base score is used instead.

    This is preferred over :class:`NormalizedWeightedReranker` when:

    * Documents have structured fields with different importance (title > content).
    * You need fine-grained control over score contributions.
    * Different fields use different scoring scales.

    Args:
        topn (int, optional): Number of top documents to return. Defaults to 10.
        rerank_field (Optional[str], optional): Ignored. Defaults to None.
        weights (Optional[dict[str, float]], optional): Weight per source key.
            Sources not listed use weight 1.0. Defaults to None (equal weights).
        source_weights (Optional[dict[str, float]], optional): Deprecated alias
            for ``weights``. Defaults to None.
        field_weights (Optional[dict[str, float]], optional): Weight per document
            field. Fields not listed use weight 1.0. The field score is read from
            the ``doc.fields`` dictionary. Defaults to None (equal weights for all
            fields).
        normalize (Union[bool, str, dict[str, Any], None], optional): Normalization
            configuration; see :meth:`__init__`. Defaults to True (standard
            normalization).
        metrics (Optional[Union[MetricType, dict]], optional): Metric type(s) for
            raw scores:

            - ``MetricType.COSINE`` : cosine distances [0, 2]
            - ``MetricType.L2`` : L2 distances
            - ``MetricType.IP`` : similarities (inner product, including BM25 scores)

            Required unless ``schema`` is provided.
        schema (Optional[CollectionSchema], optional): Collection schema used to
            infer metrics when ``metrics`` is not given. Defaults to None.

    Note:
        Field scores are expected to be stored in ``doc.fields[field_name]`` as
        numeric values. If a field is missing or has a non-numeric value, it
        contributes 0 to the score.

    Example:
        >>> reranker = MultiFieldWeightedReranker(
        ...     topn=20,
        ...     weights={"bm25": 0.7, "dense": 0.3},
        ...     field_weights={"title": 3.0, "body": 1.0, "tags": 0.5},
        ...     metrics={"bm25": MetricType.IP, "dense": MetricType.COSINE}
        ... )
        >>> results = reranker.rerank({
        ...     "bm25": bm25_docs,
        ...     "dense": dense_docs
        ... })
    """

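    # Worked example of the fusion formula (hypothetical numbers): with
    # field_weights = {"title": 3.0, "body": 1.0} and weights = {"bm25": 0.6, "dense": 0.4},
    # a document whose bm25 field scores are {"title": 0.8, "body": 0.4} and whose
    # dense field scores are {"title": 0.5, "body": 0.9} gets
    #   bm25  field-weighted score: (0.8 * 3.0 + 0.4 * 1.0) / 4.0 = 0.70
    #   dense field-weighted score: (0.5 * 3.0 + 0.9 * 1.0) / 4.0 = 0.60
    # and, assuming identity normalization, a fused score of
    #   0.6 * 0.70 + 0.4 * 0.60 = 0.66
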
    def __init__(
        self,
        topn: int = 10,
        rerank_field: Optional[str] = None,
        weights: Optional[dict[str, float]] = None,
        source_weights: Optional[dict[str, float]] = None,
        field_weights: Optional[dict[str, float]] = None,
        normalize: Union[bool, str, dict[str, Any], None] = True,
        metrics: Optional[
            Union[MetricType, dict[str, Union[str, MetricType, None]]]
        ] = _UNSET,  # type: ignore[assignment]
        schema: Optional[CollectionSchema] = None,
    ):
        """Initialize MultiFieldWeightedReranker.

        Args:
            topn: Number of top documents to return.
            rerank_field: Ignored.
            weights: Weight per source. Defaults to equal weights.
            source_weights: Deprecated alias for ``weights``.
            field_weights: Weight per document field.
            normalize: Normalization configuration. Can be:

                - ``True`` (default): Smart default - COSINE → no-op, others → "bayes"
                - ``str``: Method name ("bayes", "minmax", "percentile", "cosine")
                - ``dict``: Per-source config, e.g., ``{"sparse": "bayes", "dense": "cosine"}``
                - ``None`` or ``False``: No normalization (raw scores after conversion)

                Note: ``"cosine"`` is a no-op (identity) since COSINE scores are
                already in [0, 1] after conversion ``(2 - score) / 2``.
            metrics: Metric type(s) for converting distances to similarities. Can be
                a single MetricType for all sources, or a dict for per-source
                metrics. If None and schema is provided, metrics are inferred from
                the schema. **Required** if schema is not provided.
            schema (Optional[CollectionSchema]): Collection schema to automatically
                extract metrics from. If provided and metrics is None, metrics are
                inferred from the schema.

        Raises:
            ValueError: If neither metrics nor schema is provided.

        Example:
            >>> # Automatic metric detection from collection schema
            >>> import zvec
            >>> collection = zvec.open("./my_collection")
            >>> reranker = MultiFieldWeightedReranker(
            ...     schema=collection.schema,
            ...     weights={"bm25": 0.6, "dense": 0.4},
            ...     field_weights={"title": 3.0, "content": 1.0},
            ...     normalize=True  # Default: bayes for all
            ... )
        """
        import warnings

        # Validation: metrics or schema required
        # Use sentinel to distinguish "not provided" from "explicitly None"
        if metrics is _UNSET and schema is None:
            raise ValueError(
                "Either 'metrics' or 'schema' must be provided. "
                "For hybrid search: metrics={'sparse': MetricType.IP, 'dense': MetricType.COSINE}"
            )

        # Convert sentinel to None for internal use
        if metrics is _UNSET:
            metrics = None

        # Handle deprecated source_weights parameter
        if source_weights is not None:
            if weights is not None:
                raise ValueError(
                    "Cannot specify both 'weights' and 'source_weights'. "
                    "Use 'weights' (source_weights is deprecated)."
                )
            warnings.warn(
                "The 'source_weights' parameter is deprecated. Use 'weights' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            weights = source_weights

        # Initialize WeightedReranker (which calls RerankFunction.__init__)
        WeightedReranker.__init__(
            self,
            topn=topn,
            rerank_field=rerank_field,
            weights=weights,
            schema=schema,
            metrics=metrics,
        )
        self._field_weights = field_weights or {}
        self._normalize = normalize

    @staticmethod
    def _extract_score(doc: Doc) -> float:
        """Extract score from a document.

        Args:
            doc (Doc): Document with a score attribute.

        Returns:
            float: Score as a float, or 0.0 if score is None.
        """
        from ..utils.base_utils import extract_score

        return extract_score(doc)

    def _extract_field_score(self, doc: Doc, field_name: str) -> float:
        """Extract score from a specific document field.

        Delegates to :func:`base_utils.extract_field_score` for the actual logic.

        Args:
            doc (Doc): Document with fields attribute.
            field_name (str): Name of the field to extract score from.

        Returns:
            float: Field score as a float, or 0.0 if field is missing or non-numeric.
        """
        return extract_field_score(doc, field_name)

    def _compute_field_weighted_score(self, doc: Doc) -> float:
        """Compute weighted score across all fields for a single document.

        Args:
            doc (Doc): Document with fields and a base score.

        Returns:
            float: Weighted average of field scores. If no field_weights are set,
                or no weighted field has a positive score, returns the base
                document score.
        """
        if not self._field_weights:
            # No field weights specified, use base document score
            return self._extract_score(doc)

        total_score = 0.0
        total_weight = 0.0

        for field_name, field_weight in self._field_weights.items():
            field_score = self._extract_field_score(doc, field_name)
            if field_score > 0:
                total_score += field_score * field_weight
                total_weight += field_weight

        # If no fields matched, fall back to base document score
        if total_weight == 0:
            return self._extract_score(doc)

        return total_score / total_weight

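    # Example (hypothetical values): with field_weights = {"title": 2.0, "tags": 1.0}
    # and doc.fields = {"body": 0.9}, no weighted field has a positive score, so
    # total_weight stays 0.0 and the base document score is returned instead.
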
    def rerank(
        self, query_results: dict[str, list[Doc]], query: Optional[str] = None
    ) -> list[Doc]:
        """Normalize scores per-source and compute weighted fusion with field weighting.

        This method performs the following steps:

        1. Iterates through each source in ``query_results``.
        2. For each document, computes a field-weighted score.
        3. Applies normalization per source (smart default: COSINE → /2, others → bayes).
        4. Filters out documents with a normalized score of 0.0.
        5. Delegates to :class:`WeightedReranker` for source-weighted fusion.

        Args:
            query_results (dict[str, list[Doc]]): Dictionary mapping source names
                to lists of documents. Each document should have ``id``, ``score``,
                and ``fields`` with numeric values for field scoring.
            query (Optional[str], optional): Ignored; accepted for interface
                compatibility. Defaults to None.

        Returns:
            list[Doc]: Reranked documents with weighted normalized scores in the
                ``score`` field, sorted by descending score.

        Example:
            >>> query_results = {
            ...     "sparse_bm25": bm25_docs,
            ...     "dense_cosine": dense_docs
            ... }
            >>> reranked = reranker.rerank(query_results)
        """
        normalized_query_results: dict[str, list[Doc]] = {}

        for key, docs in query_results.items():
            if not docs:
                normalized_query_results[key] = []
                continue

            # 1. Compute field-weighted scores
            metric = self._get_metric_for_source(key)
            raw_scores = [self._compute_field_weighted_score(doc) for doc in docs]

            # 2. Convert scores (ensure higher=better)
            score_list = [
                (
                    doc.id,
                    self._convert_metric(
                        raw_scores[i],
                        source_key=key,
                    ),
                )
                for i, doc in enumerate(docs)
            ]

            # 3. Get normalizer for this source (smart default based on metric)
            normalizer = self._get_normalizer(key, metric)
            normalized_scores = normalizer(score_list)
            score_map = {uid: n_score for uid, n_score in normalized_scores}

            # 4. Update Docs with the new normalized scores
            new_docs = []
            for doc in docs:
                norm_score = score_map.get(doc.id, 0.0)
                # We filter out non-relevant docs (score 0.0) to optimize fusion
                if norm_score > 0:
                    new_docs.append(doc._replace(score=norm_score))

            normalized_query_results[key] = new_docs

        # Delegate to WeightedReranker.rerank (simple weighted fusion)
        return super().rerank(normalized_query_results)
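
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the public API): the fusion
# math used by MultiFieldWeightedReranker, restated with plain dicts so it can
# be run without a zvec collection or Doc objects. Scores below are assumed to
# be already normalized to [0, 1], so the per-source normalization step is
# omitted.
# ---------------------------------------------------------------------------
def _fusion_sketch(
    per_source_field_scores: dict[str, dict[str, dict[str, float]]],
    weights: dict[str, float],
    field_weights: dict[str, float],
) -> dict[str, float]:
    """Fuse ``{source: {doc_id: {field: score}}}`` into ``{doc_id: score}``."""
    fused: dict[str, float] = {}
    for source, docs in per_source_field_scores.items():
        w_s = weights.get(source, 1.0)
        for doc_id, fields in docs.items():
            # Field-weighted average over the fields with a positive score
            num = sum(fields[f] * w for f, w in field_weights.items() if fields.get(f, 0.0) > 0)
            den = sum(w for f, w in field_weights.items() if fields.get(f, 0.0) > 0)
            field_score = num / den if den else 0.0
            fused[doc_id] = fused.get(doc_id, 0.0) + w_s * field_score
    return fused


if __name__ == "__main__":
    # Two sources and two documents with hypothetical field scores.
    print(
        _fusion_sketch(
            {
                "bm25": {"d1": {"title": 0.8, "body": 0.4}, "d2": {"title": 0.2, "body": 0.9}},
                "dense": {"d1": {"title": 0.5, "body": 0.9}, "d2": {"title": 0.7, "body": 0.1}},
            },
            weights={"bm25": 0.6, "dense": 0.4},
            field_weights={"title": 3.0, "body": 1.0},
        )
    )
    # -> roughly {'d1': 0.66, 'd2': 0.445}: "d1" wins on its strong title match.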