Source code for zvec_db.embedders.base

"""Base classes and type definitions for sparse embedding models.

This module provides the abstract base class :class:`BaseSparseEmbedder` which
defines a common interface for all sparse embedding models in this package.
It handles tokenization, model persistence, and conversion to zvec-compatible
formats.

Constants
---------
DEFAULT_MAX_FEATURES : int
    Default maximum number of features (non-zero elements) to retain per
    document. Set to 8192 (2^13) as a power of 2 for memory alignment,
    balancing between vocabulary coverage and memory efficiency.
"""

import threading
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Optional, Tuple

import joblib
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator

from ..utils.cache import LRUCacheMixin

if TYPE_CHECKING:
    from zvec_db.preprocessing.config import NormalizationConfig

# Type aliases for zvec compatibility
SparseVector = dict[int, float]
"""A sparse vector represented as a dictionary mapping feature indices to values."""

# NOTE: PEP 604 unions (`X | Y`) evaluated at runtime require Python >= 3.10.
ExtendedList = list[str] | list[list[str]]
"""A corpus that can be either raw strings or pre-tokenized lists."""

StrExtendedList = str | list[str] | list[list[str]]
"""Input text that can be a single document or a batch."""


def _identity(x: str) -> str:
    """Identity function used as tokenizer/preprocessor placeholder.

    Defined at module level to be pickleable for model persistence.
    """
    return x


# Default maximum number of features (non-zero entries) retained per document
# by _to_zvec_dict's top-k truncation; see the module docstring for rationale.
DEFAULT_MAX_FEATURES = 8192  # 2^13: power of 2 for memory alignment


class BaseSparseEmbedder(LRUCacheMixin, BaseEstimator, ABC):
    """Abstract base class for sparse embedding models using scikit-learn.

    This class provides a unified interface for:

    * Training sparse embedding models (Count, BM25, TF-IDF)
    * Handling custom tokenization or pre-tokenized inputs
    * Converting scipy sparse matrices to zvec-compatible dictionaries
    * Saving and loading trained models
    * LRU caching for repeated embeddings (via LRUCacheMixin)

    The class supports two mutually exclusive modes:

    1. **Pre-tokenized mode** (``is_pretokenized=True``): Input documents
       are already tokenized as lists of strings.
    2. **Custom tokenizer mode** (``tokenizer=<callable>``): A user-provided
       function tokenizes each string document before vectorization.

    If neither is specified, raw strings are passed directly to the
    underlying scikit-learn vectorizer.

    Args:
        tokenizer (Optional[Callable]): A callable that takes a string and
            returns a list of tokens.
        is_pretokenized (bool): If True, input documents must be
            pre-tokenized as lists of strings.
        max_features (Optional[int]): Maximum number of features (non-zero
            elements) to retain per document. Defaults to 8192.

    Raises:
        ValueError: If both ``tokenizer`` and ``is_pretokenized=True`` are set.

    Example:
        >>> embedder = MyEmbedder(tokenizer=my_tokenize_fn)
        >>> embedder.fit(documents)
        >>> vectors = embedder.embed(["query text"])
    """
[docs] def __init__( self, tokenizer: Optional[Callable] = None, is_pretokenized: bool = False, max_features: Optional[int] = DEFAULT_MAX_FEATURES, cache_size: int = 1024, preprocessing_config: Optional["NormalizationConfig"] = None, ): if tokenizer is not None and is_pretokenized: raise ValueError( "Cannot specify both tokenizer and is_pretokenized=True. " "Use either a custom tokenizer OR pre-tokenized input, not both." ) self.tokenizer = tokenizer self.is_pretokenized = is_pretokenized self.max_features = max_features self.cache_size = cache_size self.preprocessing_config = preprocessing_config self.model: Optional[csr_matrix] = None self._embed_cache: dict[str, SparseVector] = {} self._cache_lock = threading.Lock()
def _prepare_vectorizer_params(self, params: dict) -> dict: """Configure scikit-learn vectorizer parameters for the current mode. When using custom tokenization or pre-tokenized inputs, this method disables the vectorizer's built-in tokenization to prevent double processing. Args: params (dict): Original vectorizer parameters from the subclass. Returns: dict: Updated parameters with tokenization settings adjusted. """ cv_params = params.copy() # Check if HF tokenizer is configured via preprocessing_config hf_tokenizer_enabled = ( self.preprocessing_config is not None and self.preprocessing_config.tokenizer is not None ) if self.is_pretokenized or self.tokenizer or hf_tokenizer_enabled: cv_params.update( { "tokenizer": _identity, "preprocessor": _identity, "token_pattern": None, "lowercase": False, } ) return cv_params def _apply_preprocessing(self, text: str) -> str | list: """Apply preprocessing configuration to a text. Args: text: Raw text to preprocess. Returns: Preprocessed text (str) or list of tokens (list) if HF tokenizer is configured. If no preprocessing_config is set, returns the original text. Note: If the preprocessing module is not available, a warning is logged and the original text is returned. """ if self.preprocessing_config is None: return text try: from zvec_db.preprocessing.config import normalize_text return normalize_text(text, self.preprocessing_config) except ImportError: import warnings warnings.warn( "Preprocessing module not available. Install with: " "pip install 'zvec-db[preprocessing]' or pip install nltk", ImportWarning, stacklevel=2, ) return text
[docs] def preprocess(self, text: str) -> str | list: """Apply preprocessing to a text (public API). This method applies the preprocessing configuration to a single text. It is useful for preprocessing queries or documents before embedding. Args: text: Raw text to preprocess. Returns: Preprocessed text (str) or list of tokens (list) if HF tokenizer is configured. If no preprocessing_config is set, returns the original text unchanged. Example: >>> from zvec_db.embedders import BM25Embedder >>> from zvec_db.preprocessing import NormalizationConfig >>> config = NormalizationConfig.aggressive(language="french") >>> embedder = BM25Embedder(preprocessing_config=config) >>> embedder.preprocess(" CHAT MANGEAIT ") 'chat mang' >>> config = NormalizationConfig.with_hf_tokenizer("gbert-base") >>> embedder = BM25Embedder(preprocessing_config=config) >>> embedder.preprocess("Le chat mange") ['le', 'chat', 'man', '##ge'] """ return self._apply_preprocessing(text)
def _prepare_corpus(self, corpus: ExtendedList) -> ExtendedList: """Pre-process a corpus according to the current tokenization mode. This method validates and transforms input data based on the embedder's configuration: * **Pre-tokenized mode**: Validates that each document is a list. * **Custom tokenizer**: Applies the tokenizer to each string document. * **Preprocessing**: Applies normalization if preprocessing_config is set. * **Default**: Returns corpus unchanged. Args: corpus (ExtendedList): The input corpus to process. Returns: ExtendedList: The processed corpus ready for vectorization. Raises: ValueError: If document format doesn't match the configured mode. """ if self.is_pretokenized: for doc in corpus: if not isinstance(doc, list): raise ValueError( "With is_pretokenized=True each document must be a list of tokens" ) return corpus if self.tokenizer: processed = [] for doc in corpus: if not isinstance(doc, str): raise ValueError( "When a tokenizer is set documents must be strings" ) # Apply preprocessing before tokenization doc_preprocessed = self._apply_preprocessing(doc) # If preprocessing returns tokens (HF tokenizer), skip custom tokenizer if isinstance(doc_preprocessed, list): processed.append(doc_preprocessed) else: processed.append(self.tokenizer(doc_preprocessed)) return processed # No custom tokenizer: apply preprocessing to raw strings if self.preprocessing_config is not None: # _apply_preprocessing may return str|list preprocessed: list[str | list[str]] = [ self._apply_preprocessing(doc) for doc in corpus # type: ignore[arg-type] ] # If preprocessing returns tokens (HF tokenizer), corpus is already tokenized return preprocessed # type: ignore[return-value] return corpus
[docs] @abstractmethod def fit(self, corpus: ExtendedList, y=None): """Train the sparse embedding model on a corpus. Args: corpus (ExtendedList): Training documents (strings or token lists depending on configuration). y: Ignored; present for scikit-learn compatibility. Returns: self: The fitted embedder instance. """ raise NotImplementedError
def _to_zvec_dict(self, matrix: csr_matrix) -> list[SparseVector]: """Convert a scipy CSR matrix to zvec-compatible sparse dictionaries. Each row of the sparse matrix is converted to a dictionary mapping feature indices to their values. If ``max_features`` is set, only the top-k features (by value) per document are retained. Note: Keys are automatically sorted in ascending order to work around a bug in zvec/proxima where unsorted keys cause incorrect search scores (often returning 0). Args: matrix (csr_matrix): Sparse matrix from the fitted model. Returns: list[SparseVector]: List of dictionaries, one per document. Keys in each dictionary are sorted in ascending order. Raises: RuntimeError: If the input matrix is None. """ if matrix is None: raise RuntimeError("Input matrix cannot be None") n_rows = matrix.shape[0] indptr = matrix.indptr indices = matrix.indices data = matrix.data # Fast path: no max_features limit, CSR indices are already sorted if self.max_features is None: return [ { int(k): float(v) for k, v in zip( indices[indptr[i] : indptr[i + 1]], data[indptr[i] : indptr[i + 1]], ) } for i in range(n_rows) ] # With max_features: select top-k by value, then sort by index results: list[SparseVector] = [] k = self.max_features for i in range(n_rows): start, end = indptr[i], indptr[i + 1] row_indices = indices[start:end] row_data = data[start:end] n_features = len(row_indices) if n_features > k: # argpartition is O(n) vs argsort O(n log n) for top-k selection top_k_idx = np.argpartition(row_data, -k)[-k:] # Extract top-k indices and values top_indices = row_indices[top_k_idx] top_data = row_data[top_k_idx] # Sort by index (required for zvec/proxima correctness) sort_idx = top_indices.argsort() results.append( { int(k): float(v) for k, v in zip(top_indices[sort_idx], top_data[sort_idx]) } ) else: # CSR format already stores indices in ascending order results.append( {int(k): float(v) for k, v in zip(row_indices, row_data)} ) return results
[docs] def __call__( self, input_text: StrExtendedList ) -> SparseVector | list[SparseVector]: """Call shortcut that delegates to :meth:`embed`. This allows the embedder to be called like a function:: embedder = BM25Embedder() embedder.fit(documents) vector = embedder("query text") # equivalent to embedder.embed(...) Args: input_text: Single document or batch of documents. Returns: Sparse vector(s) as dictionaries. """ return self.embed(input_text)
[docs] def preprocess_input(self, input_text: StrExtendedList) -> Tuple[bool, str | list]: """Determine if input is a single document or batch, and apply tokenization. This method normalizes all input types into a list format expected by scikit-learn models, while preserving information about the original input structure to restore the correct return type. The method handles three configurations: 1. **Pre-tokenized mode**: Validates and wraps token lists. 2. **Custom tokenizer**: Applies the tokenizer to string inputs. 3. **Default**: Wraps strings without modification. Args: input_text (StrExtendedList): Input to process. Format depends on configuration: * If ``is_pretokenized=True``: ``list[str]`` (single) or ``list[list[str]]`` (batch) * If ``tokenizer`` is set: ``str`` (single) or ``list[str]`` (batch) * Default: ``str`` (single) or ``list[str]`` (batch) Returns: Tuple[bool, str | list]: A tuple containing: * ``is_single`` (bool): True if input was a single document. * ``processed_list`` (list): Data wrapped as a list for the model. Raises: ValueError: If input format doesn't match the configuration. 
""" if self.is_pretokenized: if not isinstance(input_text, list): raise ValueError( "With is_pretokenized=True, input must be List[str] or List[List[str]]" ) if not input_text: return False, [] is_single = isinstance(input_text[0], str) return is_single, [input_text] if is_single else input_text if self.tokenizer: if isinstance(input_text, str): # Apply preprocessing before tokenization text_preprocessed = self._apply_preprocessing(input_text) return True, [self.tokenizer(text_preprocessed)] if isinstance(input_text, list): # _apply_preprocessing returns str|list, but doc is str # Narrow type: input_text is list[str] at this point (tokenizer expects strings) processed_docs: list[str | list[str]] = [ self._apply_preprocessing(doc) # type: ignore[arg-type] for doc in input_text ] return False, [self.tokenizer(doc) for doc in processed_docs] raise ValueError("With tokenizer provided, input must be str or List[str]") # No custom tokenizer: apply preprocessing if configured if isinstance(input_text, str): return True, [self._apply_preprocessing(input_text)] if isinstance(input_text, list): return False, [ self._apply_preprocessing(doc) # type: ignore[arg-type] for doc in input_text ] raise ValueError("Input must be str or List[str]")
[docs] def fit_transform(self, X, y=None) -> csr_matrix: """Fit the model and transform the data in one step. This is a convenience method that calls :meth:`fit` followed by :meth:`transform`. It is useful for training and obtaining embeddings without storing intermediate results. Args: X: Training corpus (strings or token lists). y: Ignored; present for scikit-learn compatibility. Returns: csr_matrix: Sparse matrix of fitted and transformed data. """ del y # Unused, kept for sklearn compatibility self.fit(X) return self.transform(X)
[docs] def transform(self, input_text: StrExtendedList) -> csr_matrix: """Transform input text into a sparse feature matrix. This method follows the standard scikit-learn transformer API. It automatically handles tokenization based on the embedder's configuration before passing data to the fitted model. .. note:: The model must be fitted (via :meth:`fit` or :meth:`fit_transform`) or loaded before calling this method. Args: input_text (StrExtendedList): Single document or batch of documents. Returns: csr_matrix: Sparse feature matrix with shape ``(n_docs, n_features)``. Raises: RuntimeError: If the model has not been fitted or loaded. """ if self.model is None: raise RuntimeError("Model must be fitted or loaded before transforming.") _, processed_input = self.preprocess_input(input_text) return self.model.transform(processed_input)
[docs] def embed(self, input_text: StrExtendedList) -> SparseVector | list[SparseVector]: """Embed text into sparse vectors as dictionaries. This is the primary user-facing method for generating embeddings. Unlike :meth:`transform` which returns a scipy sparse matrix, this method returns zvec-compatible dictionaries mapping ``{feature_index: value}``. The method automatically handles both single documents and batches, returning a single dictionary for a single input or a list of dictionaries for batch input. .. note:: The model must be fitted (via :meth:`fit` or :meth:`fit_transform`) or loaded before calling this method. Args: input_text (StrExtendedList): Single document or batch of documents. Returns: SparseVector | list[SparseVector]: * Single document: ``dict[int, float]`` mapping feature indices to values. * Batch: ``list[dict[int, float]]`` with one dictionary per document. Raises: RuntimeError: If the model has not been fitted or loaded. Example: >>> embedder = BM25Embedder().fit(documents) >>> vector = embedder.embed("search query") >>> vector # {42: 0.523, 108: 0.312, ...} """ return self._embed_cached(input_text)
def _embed_cached( self, input_text: StrExtendedList ) -> SparseVector | list[SparseVector]: """Internal method with LRU cache for embed().""" if self.model is None: raise RuntimeError("Model must be fitted or loaded before embedding.") is_single, processed_input = self.preprocess_input(input_text) # Use cache for single string inputs (thread-safe via mixin) if is_single and isinstance(input_text, str): return self._cached_compute( key=input_text, compute_fn=lambda: self._compute_embedding(processed_input), ) # Batch: no caching (too many unique inputs) return self._compute_embedding(processed_input) def _compute_embedding( self, processed_input: str | list ) -> SparseVector | list[SparseVector]: """Compute embedding without caching.""" # type ignore: model is checked for None in caller (_embed_cached) sparse_matrix = self.model.transform(processed_input) # type: ignore[union-attr] dicts = self._to_zvec_dict(sparse_matrix) return dicts[0] if len(dicts) == 1 else dicts
[docs] def embed_batch( self, documents: list[str], batch_size: int = 32, show_progress: bool = False, ) -> list[SparseVector]: """Embed a large batch of documents with optional progress bar. This method is optimized for processing large corpora by embedding documents in smaller batches. It supports an optional progress bar for tracking long-running operations. Args: documents (list[str]): List of documents to embed. batch_size (int, optional): Number of documents per batch. Defaults to 32. show_progress (bool, optional): Show progress bar. Defaults to False. Returns: list[SparseVector]: List of sparse vectors, one per document. Example: >>> embedder = BM25Embedder().fit(corpus) >>> vectors = embedder.embed_batch( ... large_corpus, ... batch_size=64, ... show_progress=True ... ) Note: For single documents or small batches, use :meth:`embed` instead, which includes caching for repeated inputs. """ if not documents: return [] total = len(documents) results: list[SparseVector] = [] # Optional progress bar if show_progress: try: from tqdm import tqdm iterator = tqdm( range(0, total, batch_size), desc="Embedding", unit="batch", ) except ImportError: # tqdm not installed, fall back to simple progress iterator = range(0, total, batch_size) show_progress = False else: iterator = range(0, total, batch_size) for i in iterator: batch = documents[i : i + batch_size] batch_result = self.embed(batch) if isinstance(batch_result, list): results.extend(batch_result) else: results.append(batch_result) if show_progress: # iterator may be tqdm or range, set_postfix is on tqdm if hasattr(iterator, "set_postfix"): iterator.set_postfix( {"processed": min(i + batch_size, total), "total": total} ) return results
[docs] def save(self, path: str) -> str: """Serialize the model and tokenizer to disk. The model is saved using joblib, which efficiently handles the scikit-learn pipeline and any fitted parameters. Args: path (str): File path where the model will be saved. Returns: str: The path where the model was saved (same as input). Example: >>> embedder.fit(documents) >>> embedder.save("models/bm25_model.joblib") """ joblib.dump( { "model": self.model, "tokenizer": self.tokenizer, "preprocessing_config": self.preprocessing_config, "is_pretokenized": self.is_pretokenized, "max_features": self.max_features, "cache_size": self.cache_size, }, path, ) return path
[docs] def save_pretrained(self, path: str) -> str: """Alias for :meth:`save`. This method is provided for compatibility with common naming conventions in NLP libraries (e.g., Hugging Face Transformers). Args: path (str): File path where the model will be saved. Returns: str: The path where the model was saved. """ return self.save(path)
[docs] def load(self, path: str) -> None: """Load a serialized model and tokenizer from disk. This method restores the model state from a file previously saved with :meth:`save` or :meth:`save_pretrained`. The preprocessing configuration and other settings (is_pretokenized, max_features) are also restored. Args: path (str): File path to the serialized model. Returns: None Example: >>> embedder = BM25Embedder() >>> embedder.load("models/bm25_model.joblib") """ data = joblib.load(path) self.model = data["model"] self.tokenizer = data["tokenizer"] self.preprocessing_config = data.get("preprocessing_config") self.is_pretokenized = data.get("is_pretokenized", False) self.max_features = data.get("max_features") self.cache_size = data.get("cache_size", 1024)
[docs] def from_pretrained(self, path: str) -> None: """Alias for :meth:`load`. This method is provided for compatibility with common naming conventions in NLP libraries (e.g., Hugging Face Transformers). Args: path (str): File path to the serialized model. Returns: None """ self.load(path)