Source code for zvec_db.rerankers.base

"""Base classes for reranking in zvec-db."""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from zvec.model.doc import Doc
from zvec.typing import MetricType

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from zvec.model.schema import CollectionSchema


def _extract_metrics_from_schema(
    schema: CollectionSchema,
) -> dict[str, str | MetricType | None]:
    """Extract metric types from a collection schema.

    Args:
        schema (CollectionSchema): The collection schema to extract metrics from.

    Returns:
        dict[str, str | MetricType | None]: Mapping from vector field name to metric type.
        Defaults to MetricType.IP if metric_type is not specified (zvec default).
        None means no conversion needed (equivalent to IP).
    """
    metrics: dict[str, str | MetricType | None] = {}
    for vec in schema.vectors:
        if hasattr(vec, "index_param") and vec.index_param is not None:
            metric_type = getattr(vec.index_param, "metric_type", MetricType.IP)
            metrics[vec.name] = (
                metric_type if metric_type is not None else MetricType.IP
            )
        else:
            metrics[vec.name] = MetricType.IP
    return metrics
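
# Hedged usage sketch: given a schema with one COSINE-indexed vector field
# and one field without an explicit metric, the helper returns a
# name -> metric mapping, with the unspecified field falling back to the
# zvec default (IP). Field names are illustrative:
#
#     >>> _extract_metrics_from_schema(schema)
#     {'title_vec': MetricType.COSINE, 'body_vec': MetricType.IP}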


def _validate_metric_type(
    metric: str | MetricType | None, source_name: str = ""
) -> None:
    """Validate a metric type (COSINE, IP, L2, or None).

    Args:
        metric: Metric type to validate. None means no conversion (same as IP).
        source_name: Name of the source for error message.

    Raises:
        ValueError: If the metric type is not valid.
    """
    valid_metrics = {
        MetricType.COSINE,
        MetricType.IP,
        MetricType.L2,
        "COSINE",
        "IP",
        "L2",
    }
    if metric is not None and metric not in valid_metrics:
        source_msg = f" for source '{source_name}'" if source_name else ""
        raise ValueError(
            f"Invalid metric type '{metric}'{source_msg}. "
            f"Valid metric types are: COSINE, IP, L2, or None."
        )
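
# Hedged sketch of the validation behavior: strings and MetricType members
# for COSINE/IP/L2 pass, None passes (no conversion), anything else raises.
# 'EUCLIDEAN' below is a deliberately invalid example value:
#
#     >>> _validate_metric_type(MetricType.COSINE)
#     >>> _validate_metric_type(None)
#     >>> _validate_metric_type("EUCLIDEAN", source_name="body_vec")
#     Traceback (most recent call last):
#         ...
#     ValueError: Invalid metric type 'EUCLIDEAN' for source 'body_vec'. ...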


class RerankFunction(ABC):
    """Abstract base class for reranking search results.

    Rerankers refine the output of one or more vector queries by applying a
    secondary scoring strategy. They are used in the ``query()`` method of
    ``Collection`` via the ``reranker`` parameter.

    Args:
        topn (int, optional): Number of top documents to return after
            reranking. Defaults to 10.
        rerank_field (str | None, optional): Field name used as input for
            reranking (e.g., document title or body). Defaults to None.
        schema (CollectionSchema | None, optional): Collection schema to
            automatically extract metrics from. If provided and no explicit
            metrics are given, metric types are inferred from the schema.
            Defaults to None.
        metrics (str | MetricType | dict[str, str | MetricType | None] | None, optional):
            Metric type(s) for converting distances to similarities. Can be:

            - A single MetricType (e.g., ``MetricType.COSINE``) applied to
              all sources
            - A dict mapping source names to their metric type (use ``None``
              or ``MetricType.IP`` for sources that don't need conversion,
              e.g., BM25 scores)
            - If None and schema is provided, metrics are inferred from the
              schema (defaults to IP if not specified)
            - If None and no schema, defaults to IP (no conversion needed)

            Defaults to None.

    Note:
        Subclasses must implement the ``rerank()`` method.
    """

    def __init__(
        self,
        topn: int = 10,
        rerank_field: str | None = None,
        schema: CollectionSchema | None = None,
        metrics: str | MetricType | dict[str, str | MetricType | None] | None = None,
    ):
        self._topn = topn
        self._rerank_field = rerank_field
        self._schema = schema
        # Resolve metrics: explicit value > schema inference > empty dict
        self._metrics = self._resolve_metrics(metrics)
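
    # Hedged construction sketch (``MyFusionReranker`` is a hypothetical
    # subclass, and the source names are illustrative):
    #
    #     >>> reranker = MyFusionReranker(
    #     ...     topn=5,
    #     ...     rerank_field="body",
    #     ...     metrics={"dense": MetricType.COSINE, "bm25": None},
    #     ... )
    #     >>> reranker.metrics
    #     {'dense': MetricType.COSINE, 'bm25': None}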

    @property
    def topn(self) -> int:
        """int: Number of top documents to return after reranking."""
        return self._topn

    @property
    def rerank_field(self) -> str | None:
        """str | None: Field name used as reranking input."""
        return self._rerank_field

    @property
    def schema(self) -> CollectionSchema | None:
        """CollectionSchema | None: The collection schema if provided."""
        return self._schema

    @property
    def metrics(self) -> dict[str, str | MetricType | None]:
        """dict[str, str | MetricType | None]: Per-source metric types."""
        return self._metrics

    def _resolve_metrics(
        self,
        metrics: str | MetricType | dict[str, str | MetricType | None] | None,
    ) -> dict[str, str | MetricType | None]:
        """Resolve metrics priority: explicit > schema inference > empty dict.

        Args:
            metrics: Metric type(s) provided by user. Can be:

                - A single MetricType applied to all sources (resolved at
                  runtime)
                - A dict mapping source names to their metric type
                  (None = no conversion)
                - None to infer from schema or use no conversion
                  (IP for non-vector scores)

        Returns:
            dict[str, str | MetricType | None]: Mapping from source to metric
            type. None means no conversion (equivalent to IP).

        Raises:
            ValueError: If an invalid metric type is provided.
        """
        # Case 1: Single MetricType - validate and return
        if isinstance(metrics, (str, MetricType)):
            _validate_metric_type(metrics)
            return {"_global": metrics}

        # Case 2: Dict of metrics per source - validate each
        if isinstance(metrics, dict):
            for source_name, metric in metrics.items():
                _validate_metric_type(metric, source_name)
            return metrics

        # Case 3: None - try schema inference
        if metrics is None and self._schema is not None:
            return self._extract_metrics_from_schema()

        # Case 4: None and no schema - no conversion
        return {}

    def _extract_metrics_from_schema(
        self,
    ) -> dict[str, str | MetricType | None]:
        """Extract metric types from the stored schema.

        Returns:
            dict[str, str | MetricType | None]: Mapping from vector field name
            to metric type. Returns an empty dict if no schema is available.
            Defaults to MetricType.IP if metric_type is not specified.
        """
        if self._schema is None:
            return {}
        return _extract_metrics_from_schema(self._schema)

    def _get_metric_for_source(
        self,
        source_key: str | None,
    ) -> str | MetricType | None:
        """Get the metric type for a specific source.

        Args:
            source_key: The source name to get the metric for.

        Returns:
            str | MetricType | None: The metric type for this source. None or
            MetricType.IP means no conversion (for BM25/non-vector scores).
            If a global metric was set (single MetricType), it is returned
            for all sources.
        """
        # Global metric applies to all sources
        if "_global" in self._metrics:
            return self._metrics["_global"]
        # Per-source metric, default to IP (no conversion needed)
        return self._metrics.get(source_key, MetricType.IP)  # type: ignore[arg-type]
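
    # Hedged sketch of how the three ``metrics`` forms resolve (source names
    # are illustrative):
    #
    #     >>> r._resolve_metrics(MetricType.COSINE)   # one metric, all sources
    #     {'_global': MetricType.COSINE}
    #     >>> r._resolve_metrics({"dense": "L2", "bm25": None})
    #     {'dense': 'L2', 'bm25': None}               # None = no conversion
    #     >>> r._resolve_metrics(None)                # schema inference, else {}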

    @abstractmethod
    def rerank(
        self,
        query_results: dict[str, list[Doc]],
        query: str | None = None,
    ) -> list[Doc]:
        """Rerank documents from one or more vector queries.

        Args:
            query_results (dict[str, list[Doc]]): Mapping from vector field
                name to list of retrieved documents (sorted by relevance).
            query (str | None, optional): The search query. Some rerankers
                may require this (e.g., CrossEncoder). Defaults to None.

        Returns:
            list[Doc]: Reranked list of documents (length <= topn), with
            updated ``score`` fields.
        """
        ...
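
# Hedged sketch of a minimal concrete subclass. ``rerank()`` is the only
# abstract method; this toy ``MaxScoreReranker`` (not part of the library)
# merges all sources and keeps the best score per document id:
#
#     class MaxScoreReranker(RerankFunction):
#         def rerank(
#             self,
#             query_results: dict[str, list[Doc]],
#             query: str | None = None,
#         ) -> list[Doc]:
#             best: dict[str, Doc] = {}
#             for docs in query_results.values():
#                 for doc in docs:
#                     kept = best.get(doc.id)
#                     if kept is None or (doc.score or 0.0) > (kept.score or 0.0):
#                         best[doc.id] = doc
#             ranked = sorted(
#                 best.values(), key=lambda d: d.score or 0.0, reverse=True
#             )
#             return ranked[: self.topn]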


class FusionRerankerBase(RerankFunction):
    """Base class for fusion-based rerankers combining multiple sources.

    This class provides shared functionality for rerankers that fuse scores
    from multiple retrieval sources, including metric conversion and
    normalization.

    Conversion formulas (ensure higher=better):

    - **COSINE**: ``(2 - score) / 2`` - distance [0, 2] -> similarity [0, 1]
    - **L2**: ``-score`` - inverts order
    - **IP**: no conversion - already "higher=better" (also for
      BM25/non-vector scores)

    Normalization:

    - **COSINE**: NEVER normalized (conversion already produces [0, 1])
    - **Others**: Optional normalization (bayes, minmax, percentile, atan,
      etc.)
    """

    # -------------------------------------------------------------------------
    # Metric Conversion
    # -------------------------------------------------------------------------

    def _convert_metric(
        self,
        score: float,
        source_key: str | None = None,
    ) -> float:
        """Convert metric to ensure higher=better (orientation inversion only).

        This step ONLY inverts the order if needed - it does NOT normalize
        to [0, 1]. Normalization is applied separately.

        Args:
            score (float): Raw score (distance or similarity).
            source_key (str | None): Source name for metric lookup.

        Returns:
            float: Score with correct orientation (higher=better).
        """
        metric = self._get_metric_for_source(source_key)
        if metric == MetricType.COSINE:
            return (2.0 - score) / 2.0
        if metric == MetricType.L2:
            return -score
        # IP: already "higher=better" (also used for BM25/non-vector scores)
        return score

    # -------------------------------------------------------------------------
    # Normalization
    # -------------------------------------------------------------------------

    @staticmethod
    def _cosine_normalizer():
        """Normalize COSINE scores by dividing by 2.

        Note:
            This is rarely needed because _convert_metric already applies
            (2 - score) / 2 for COSINE, which both inverts orientation AND
            normalizes to [0, 1]. This normalizer exists only for API
            consistency.

        Returns:
            Callable that takes a score list and returns divided scores.
        """

        def normalize(scores: list[tuple[str, float]]) -> list[tuple[str, float]]:
            return [(uid, s / 2.0) for uid, s in scores]

        return normalize

    @staticmethod
    def _identity_normalizer():
        """Return scores unchanged (no normalization).

        Returns:
            Callable that takes a score list and returns it as-is.
        """

        def normalize(scores: list[tuple[str, float]]) -> list[tuple[str, float]]:
            return scores

        return normalize

    def _make_normalizer(self, method: str, metric: MetricType | None = None):
        """Create a normalizer for a given method.

        Args:
            method: Normalization method name ("bayes", "minmax",
                "percentile", "atan", "default").
            metric: Optional metric type for methods that depend on it
                (e.g., "atan").

        Returns:
            Callable that takes a score list and returns normalized scores.
        """
        # Lazy import to avoid circular dependency
        from .utils.normalize import Normalize

        if method == "atan":
            # atan normalization depends on the metric type
            def atan_normalize(
                scores: list[tuple[str, float]],
            ) -> list[tuple[str, float]]:
                return Normalize._atan(scores, metric)

            return atan_normalize
        return Normalize(method)

    def _get_normalizer(
        self,
        source_key: str,
        metric: str | MetricType | None,
        user_config: bool | str | dict | None = None,
    ):
        """Get the normalizer for a source based on user config and metric.

        Note:
            COSINE metric is NEVER normalized - conversion (2 - score) / 2
            already produces scores in [0, 1].

        Args:
            source_key: Source name for per-source config lookup.
            metric: Metric type for the smart default (COSINE -> no norm,
                others -> bayes).
            user_config: User's normalize config. If None, uses
                self._normalize.

        Returns:
            Callable that takes a score list and returns normalized scores.
        """
        if user_config is None:
            user_config = getattr(self, "_normalize", None)

        # COSINE is never normalized - conversion already produces [0, 1]
        if metric == MetricType.COSINE:
            return self._identity_normalizer()

        # None/False -> no normalization
        if user_config is None or user_config is False:
            return self._identity_normalizer()

        # Dict -> per-source config
        if isinstance(user_config, dict):
            source_config = user_config.get(source_key)
            if source_config is None or source_config is False:
                return self._identity_normalizer()
            if isinstance(source_config, str):
                return self._make_normalizer(source_config, metric)  # type: ignore[arg-type]
            # Lazy import to avoid circular dependency
            from .utils.normalize import Normalize

            return Normalize(source_config)

        # String -> force this method everywhere
        if isinstance(user_config, str):
            return self._make_normalizer(user_config, metric)  # type: ignore[arg-type]

        # True -> smart default (bayes for non-COSINE)
        # Lazy import to avoid circular dependency
        from .utils.normalize import Normalize

        return Normalize("bayes")

    # -------------------------------------------------------------------------
    # Score Processing Utility
    # -------------------------------------------------------------------------

    def _process_source_scores(
        self,
        source_key: str,
        docs: list[Doc],
    ) -> tuple[dict[str, float], dict[str, Doc]]:
        """Process scores from one source: extract, convert, normalize.

        This utility method encapsulates the common pattern used by fusion
        rerankers:

        1. Extract raw scores from documents
        2. Look up the metric type for the source
        3. Convert metric scores to ensure higher=better
        4. Apply normalization
        5. Return score map and document map

        Args:
            source_key: Source name for metric and normalization lookup.
            docs: List of documents with score attributes.

        Returns:
            Tuple of (score_map, id_to_doc) where:

            - score_map: dict mapping doc.id -> normalized score
            - id_to_doc: dict mapping doc.id -> Doc object

        Note:
            This method is used internally by WeightedReranker and
            MultiFieldWeightedReranker.
        """
        # 1. Extract raw scores (missing scores count as 0.0)
        raw_scores = [
            float(doc.score) if doc.score is not None else 0.0 for doc in docs
        ]

        # 2. Get metric
        metric = self._get_metric_for_source(source_key)
        logger.debug(
            "Processing source '%s': %d docs, metric=%s", source_key, len(docs), metric
        )

        # 3. Convert scores (ensure higher=better)
        score_list = [
            (
                doc.id,
                self._convert_metric(
                    raw_scores[i],
                    source_key=source_key,
                ),
            )
            for i, doc in enumerate(docs)
        ]

        # 4. Get normalizer for this source
        normalizer = self._get_normalizer(
            source_key, metric, getattr(self, "_normalize", None)
        )
        normalized_scores = normalizer(score_list)

        # 5. Build maps
        score_map = {uid: n_score for uid, n_score in normalized_scores}
        id_to_doc = {doc.id: doc for doc in docs}

        logger.debug(
            "Source '%s': %d docs processed, %d with non-zero normalized scores",
            source_key,
            len(docs),
            sum(1 for s in score_map.values() if s > 0),
        )

        return score_map, id_to_doc
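

# Hedged numeric sketch of the conversion formulas implemented by
# ``_convert_metric`` (values are illustrative):
#
#     COSINE distance 0.4 -> (2 - 0.4) / 2 = 0.8   # similarity in [0, 1]
#     L2 distance     1.5 -> -1.5                  # order inverted, unbounded
#     IP / BM25 score 7.2 -> 7.2                   # unchanged
#
# For a reranker whose metrics resolve to {"dense": MetricType.L2}:
#
#     >>> r._convert_metric(1.5, source_key="dense")
#     -1.5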