"""Base classes for reranking in zvec-db."""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from zvec.model.doc import Doc
from zvec.typing import MetricType
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from zvec.model.schema import CollectionSchema
def _extract_metrics_from_schema(
schema: CollectionSchema,
) -> dict[str, str | MetricType | None]:
"""Extract metric types from a collection schema.
Args:
schema (CollectionSchema): The collection schema to extract metrics from.
Returns:
dict[str, str | MetricType | None]: Mapping from vector field name to metric type.
Defaults to MetricType.IP if metric_type is not specified (zvec default).
None means no conversion needed (equivalent to IP).
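
    Example:
        A sketch of the expected output shape (field names are illustrative,
        not part of the zvec API)::

            # schema has "title_vec" with a COSINE index and "body_vec" with none
            _extract_metrics_from_schema(schema)
            # -> {"title_vec": MetricType.COSINE, "body_vec": MetricType.IP}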
"""
metrics: dict[str, str | MetricType | None] = {}
for vec in schema.vectors:
if hasattr(vec, "index_param") and vec.index_param is not None:
metric_type = getattr(vec.index_param, "metric_type", MetricType.IP)
metrics[vec.name] = (
metric_type if metric_type is not None else MetricType.IP
)
else:
metrics[vec.name] = MetricType.IP
return metrics
def _validate_metric_type(
    metric: str | MetricType | None, source_name: str = ""
) -> None:
    """Validate a metric type against the supported set (COSINE, IP, L2, or None).
Args:
metric: Metric type to validate. None means no conversion (same as IP).
source_name: Name of the source for error message.
Raises:
ValueError: If the metric type is not valid.
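
    Example:
        >>> _validate_metric_type("IP")  # accepted silently
        >>> _validate_metric_type("HAMMING", "title")
        Traceback (most recent call last):
            ...
        ValueError: Invalid metric type 'HAMMING' for source 'title'. Valid metric types are: COSINE, IP, L2, or None.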
"""
valid_metrics = {
MetricType.COSINE,
MetricType.IP,
MetricType.L2,
"COSINE",
"IP",
"L2",
}
if metric is not None and metric not in valid_metrics:
source_msg = f" for source '{source_name}'" if source_name else ""
raise ValueError(
f"Invalid metric type '{metric}'{source_msg}. "
f"Valid metric types are: COSINE, IP, L2, or None."
)
class RerankFunction(ABC):
"""Abstract base class for reranking search results.
Rerankers refine the output of one or more vector queries by applying
a secondary scoring strategy. They are used in the ``query()`` method of
``Collection`` via the ``reranker`` parameter.
Args:
topn (int, optional): Number of top documents to return after reranking.
Defaults to 10.
rerank_field (str | None, optional): Field name used as input for
reranking (e.g., document title or body). Defaults to None.
schema (CollectionSchema | None, optional): Collection schema to
automatically extract metrics from. If provided and no explicit
metrics are given, metric types are inferred from the schema.
Defaults to None.
metrics (str | MetricType | dict[str, str | MetricType | None] | None, optional):
Metric type(s) for converting distances to similarities. Can be:
- A single MetricType (e.g., ``MetricType.COSINE``) applied
to all sources
- A dict mapping source names to their metric type
(use ``None`` or ``MetricType.IP`` for sources that don't need conversion,
e.g., BM25 scores)
- If None and schema is provided, metrics are inferred
from the schema (defaults to IP if not specified)
- If None and no schema, defaults to IP (no conversion needed)
Defaults to None.
Note:
Subclasses must implement the ``rerank()`` method.
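
    Example:
        A minimal usage sketch (``MyReranker`` is a hypothetical subclass,
        and the ``query()`` call shape is illustrative)::

            reranker = MyReranker(topn=5, metrics=MetricType.COSINE)
            docs = collection.query("what is zvec?", reranker=reranker)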
"""
def __init__(
self,
topn: int = 10,
rerank_field: str | None = None,
schema: CollectionSchema | None = None,
metrics: str | MetricType | dict[str, str | MetricType | None] | None = None,
):
self._topn = topn
self._rerank_field = rerank_field
self._schema = schema
# Resolve metrics: explicit value > schema inference > empty dict
self._metrics = self._resolve_metrics(metrics)
@property
def topn(self) -> int:
"""int: Number of top documents to return after reranking."""
return self._topn
@property
def rerank_field(self) -> str | None:
"""str | None: Field name used as reranking input."""
return self._rerank_field
@property
def schema(self) -> CollectionSchema | None:
"""CollectionSchema | None: The collection schema if provided."""
return self._schema
@property
def metrics(self) -> dict[str, str | MetricType | None]:
"""dict[str, str | MetricType | None]: Per-source metric types."""
return self._metrics
def _resolve_metrics(
self,
metrics: str | MetricType | dict[str, str | MetricType | None] | None,
) -> dict[str, str | MetricType | None]:
"""Resolve metrics priority: explicit > schema inference > empty dict.
Args:
metrics: Metric type(s) provided by user. Can be:
- A single MetricType applied to all sources (resolved at runtime)
- A dict mapping source names to their metric type (None = no conversion)
- None to infer from schema or use no conversion (IP for non-vector scores)
Returns:
dict[str, str | MetricType | None]: Mapping from source to metric type.
None means no conversion (equivalent to IP).
Raises:
ValueError: If an invalid metric type is provided.
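
        Example:
            Resolution for each input shape (values illustrative)::

                _resolve_metrics(MetricType.L2)   # -> {"_global": MetricType.L2}
                _resolve_metrics({"title_vec": "COSINE", "bm25": None})
                                                  # -> returned as-is (validated)
                _resolve_metrics(None)            # -> schema inference, else {}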
"""
# Case 1: Single MetricType - validate and return
if isinstance(metrics, (str, MetricType)):
_validate_metric_type(metrics)
return {"_global": metrics}
# Case 2: Dict of metrics per source - validate each
if isinstance(metrics, dict):
for source_name, metric in metrics.items():
_validate_metric_type(metric, source_name)
return metrics
# Case 3: None - try schema inference
if metrics is None and self._schema is not None:
return self._extract_metrics_from_schema()
# Case 4: None and no schema - no conversion
return {}
def _extract_metrics_from_schema(
self,
) -> dict[str, str | MetricType | None]:
"""Extract metric types from the stored schema.
Returns:
dict[str, str | MetricType | None]: Mapping from vector field name to metric type.
Returns empty dict if no schema is available.
Defaults to MetricType.IP if metric_type is not specified.
"""
if self._schema is None:
return {}
return _extract_metrics_from_schema(self._schema)
def _get_metric_for_source(
self,
source_key: str | None,
) -> str | MetricType | None:
"""Get metric type for a specific source.
Args:
source_key: The source name to get the metric for.
Returns:
MetricType | str | None: The metric type for this source.
None or MetricType.IP means no conversion (for BM25/non-vector scores).
If a global metric was set (single MetricType), returns it for all sources.
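
        Example:
            With ``self._metrics == {"_global": "COSINE"}`` every source
            resolves to COSINE; with ``{"title_vec": "L2"}``, ``"title_vec"``
            resolves to L2 and any other source falls back to ``MetricType.IP``.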
"""
# Global metric applies to all sources
if "_global" in self._metrics:
return self._metrics["_global"]
# Per-source metric, default to IP (no conversion needed)
return self._metrics.get(source_key, MetricType.IP) # type: ignore[arg-type]
@abstractmethod
def rerank(
self,
query_results: dict[str, list[Doc]],
query: str | None = None,
) -> list[Doc]:
"""Rerank documents from one or more vector queries.
Args:
query_results (dict[str, list[Doc]]): Mapping from vector field name
to list of retrieved documents (sorted by relevance).
query (str | None, optional): The search query. Some rerankers
may require this (e.g., CrossEncoder). Defaults to None.
Returns:
list[Doc]: Reranked list of documents (length <= topn),
with updated ``score`` fields.
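
        Example:
            Expected input shape (field and document names illustrative)::

                reranker.rerank(
                    {"title_vec": [doc_a, doc_b], "body_vec": [doc_b, doc_c]},
                    query="what is zvec?",
                )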
"""
...
class FusionRerankerBase(RerankFunction):
"""Base class for fusion-based rerankers combining multiple sources.
This class provides shared functionality for rerankers that fuse scores
from multiple retrieval sources, including metric conversion and normalization.
Conversion formulas (ensure higher=better):
- **COSINE**: ``(2 - score) / 2`` - distance [0, 2] -> similarity [0, 1]
- **L2**: ``-score`` - inverts order
- **IP**: no conversion - already "higher=better" (also for BM25/non-vector scores)
Normalization:
- **COSINE**: NEVER normalized (conversion already produces [0, 1])
- **Others**: Optional normalization (bayes, minmax, percentile, atan, etc.)
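
    Example:
        A COSINE distance of ``0.4`` converts to ``(2 - 0.4) / 2 = 0.8``,
        already in [0, 1]; an L2 distance of ``1.3`` converts to ``-1.3``,
        which preserves ranking order but is unbounded, hence the optional
        normalization step for non-COSINE metrics.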
"""
# -------------------------------------------------------------------------
# Metric Conversion
# -------------------------------------------------------------------------
def _convert_metric(
self,
score: float,
source_key: str | None = None,
) -> float:
"""Convert metric to ensure higher=better (orientation inversion only).
This step ONLY inverts the order if needed - it does NOT normalize to [0, 1].
Normalization is applied separately.
Args:
score (float): Raw score (distance or similarity).
source_key (str | None): Source name for metric lookup.
Returns:
float: Score with correct orientation (higher=better).
"""
metric = self._get_metric_for_source(source_key)
if metric == MetricType.COSINE:
return (2.0 - score) / 2.0
if metric == MetricType.L2:
return -score
# IP: already "higher=better" (also used for BM25/non-vector scores)
return score
# -------------------------------------------------------------------------
# Normalization
# -------------------------------------------------------------------------
@staticmethod
def _cosine_normalizer():
"""Normalize COSINE scores by dividing by 2.
Note: This is rarely needed because _convert_metric already applies
(2 - score) / 2 for COSINE, which both inverts orientation AND normalizes
to [0, 1]. This normalizer exists only for API consistency.
Returns:
Callable that takes a score list and returns divided scores.
"""
def normalize(scores: list[tuple[str, float]]) -> list[tuple[str, float]]:
return [(uid, s / 2.0) for uid, s in scores]
return normalize
@staticmethod
def _identity_normalizer():
"""Return scores unchanged (no normalization).
Returns:
Callable that takes a score list and returns it as-is.
"""
def normalize(scores: list[tuple[str, float]]) -> list[tuple[str, float]]:
return scores
return normalize
def _make_normalizer(self, method: str, metric: MetricType | None = None):
"""Create a normalizer for a given method.
Args:
method: Normalization method name ("bayes", "minmax", "percentile", "atan", "default").
metric: Optional metric type for methods that depend on it (e.g., "atan").
Returns:
Callable that takes a score list and returns normalized scores.
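
        Example:
            A sketch, assuming ``Normalize("minmax")`` implements standard
            min-max scaling::

                norm = self._make_normalizer("minmax")
                norm([("a", 1.0), ("b", 3.0)])  # -> [("a", 0.0), ("b", 1.0)]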
"""
# Lazy import to avoid circular dependency
from .utils.normalize import Normalize
if method == "atan":
# atan normalization depends on the metric type
def atan_normalize(
scores: list[tuple[str, float]],
) -> list[tuple[str, float]]:
return Normalize._atan(scores, metric)
return atan_normalize
return Normalize(method)
def _get_normalizer(
self,
source_key: str,
metric: str | MetricType | None,
user_config: bool | str | dict | None = None,
):
"""Get normalizer for a source based on user config and metric.
Note: COSINE metric is NEVER normalized - conversion (2-score)/2
already produces scores in [0, 1].
Args:
source_key: Source name for per-source config lookup.
metric: Metric type for smart default (COSINE -> no norm, others -> bayes).
user_config: User's normalize config. If None, uses self._normalize.
Returns:
Callable that takes a score list and returns normalized scores.
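
        Example:
            Resolution for common configs (assuming ``self._normalize`` is
            unset, so the ``user_config`` argument decides)::

                _get_normalizer("t", MetricType.COSINE, "bayes")      # identity (COSINE)
                _get_normalizer("b", MetricType.L2, None)             # identity (no config)
                _get_normalizer("b", MetricType.L2, {"b": "minmax"})  # minmax for "b"
                _get_normalizer("b", MetricType.L2, True)             # bayes default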
"""
if user_config is None:
user_config = getattr(self, "_normalize", None)
# COSINE is never normalized - conversion already produces [0, 1]
if metric == MetricType.COSINE:
return self._identity_normalizer()
# None/False -> no normalization
if user_config is None or user_config is False:
return self._identity_normalizer()
# Dict -> per-source config
if isinstance(user_config, dict):
source_config = user_config.get(source_key)
if source_config is None or source_config is False:
return self._identity_normalizer()
if isinstance(source_config, str):
return self._make_normalizer(source_config, metric) # type: ignore[arg-type]
# Lazy import to avoid circular dependency
from .utils.normalize import Normalize
return Normalize(source_config)
# String -> force this method everywhere
if isinstance(user_config, str):
return self._make_normalizer(user_config, metric) # type: ignore[arg-type]
# True -> smart default (bayes for non-COSINE)
# Lazy import to avoid circular dependency
from .utils.normalize import Normalize
return Normalize("bayes")
# -------------------------------------------------------------------------
# Score Processing Utility
# -------------------------------------------------------------------------
def _process_source_scores(
self,
source_key: str,
docs: list[Doc],
) -> tuple[dict[str, float], dict[str, Doc]]:
"""Process scores from one source: extract, convert, normalize.
This utility method encapsulates the common pattern used by fusion
rerankers:
1. Extract raw scores from documents
        2. Look up the metric type for the source
3. Convert metric scores to ensure higher=better
4. Apply normalization
5. Return score map and document map
Args:
source_key: Source name for metric and normalization lookup.
docs: List of documents with score attributes.
Returns:
Tuple of (score_map, id_to_doc) where:
- score_map: dict mapping doc.id -> normalized score
- id_to_doc: dict mapping doc.id -> Doc object
Note:
This method is used internally by WeightedReranker
and MultiFieldWeightedReranker.
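
        Example:
            For an L2 source whose docs carry raw distances ``[0.2, 1.5]``,
            step 3 flips them to ``[-0.2, -1.5]`` (higher=better) and step 4
            normalizes them with the configured method; the returned
            ``score_map`` is keyed by ``doc.id``.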
"""
# 1. Extract raw scores
        raw_scores = [float(doc.score) if doc.score is not None else 0.0 for doc in docs]
# 2. Get metric
metric = self._get_metric_for_source(source_key)
logger.debug(
"Processing source '%s': %d docs, metric=%s", source_key, len(docs), metric
)
# 3. Convert scores (ensure higher=better)
score_list = [
(
doc.id,
self._convert_metric(
raw_scores[i],
source_key=source_key,
),
)
for i, doc in enumerate(docs)
]
# 4. Get normalizer for this source
normalizer = self._get_normalizer(
source_key, metric, getattr(self, "_normalize", None)
)
normalized_scores = normalizer(score_list)
# 5. Build maps
        score_map = dict(normalized_scores)
id_to_doc = {doc.id: doc for doc in docs}
logger.debug(
"Source '%s': %d docs processed, %d with non-zero normalized scores",
source_key,
len(docs),
sum(1 for s in score_map.values() if s > 0),
)
return score_map, id_to_doc