"""Base classes and type definitions for sparse embedding models.
This module provides the abstract base class :class:`BaseSparseEmbedder` which
defines a common interface for all sparse embedding models in this package.
It handles tokenization, model persistence, and conversion to zvec-compatible
formats.
Constants
---------
DEFAULT_MAX_FEATURES : int
Default maximum number of features (non-zero elements) to retain per
document. Set to 8192 (2^13) as a power of 2 for memory alignment,
balancing between vocabulary coverage and memory efficiency.
"""
import threading
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Optional, Tuple
import joblib
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from ..utils.cache import LRUCacheMixin
if TYPE_CHECKING:
from zvec_db.preprocessing.config import NormalizationConfig
# Type aliases for zvec compatibility.
# NOTE: the `X | Y` unions below are evaluated at import time, so this module
# requires Python 3.10+.
SparseVector = dict[int, float]
"""A sparse vector represented as a dictionary mapping feature indices to values."""
ExtendedList = list[str] | list[list[str]]
"""A corpus that can be either raw strings or pre-tokenized lists."""
StrExtendedList = str | list[str] | list[list[str]]
"""Input text that can be a single document or a batch."""
def _identity(x: str | list[str]) -> str | list[str]:
    """Return *x* unchanged (identity function).

    Used as the ``tokenizer``/``preprocessor`` placeholder when the
    vectorizer's built-in text processing is disabled (see
    ``_prepare_vectorizer_params``). Depending on the mode, *x* may be a raw
    string or an already-tokenized list of strings — hence the widened
    annotation.

    Defined at module level (rather than as a lambda) so it remains
    pickleable for model persistence.
    """
    return x


# Default maximum number of non-zero features retained per document.
# 8192 == 2**13: a power of 2 chosen for memory alignment, balancing
# vocabulary coverage against memory efficiency (see module docstring).
DEFAULT_MAX_FEATURES = 8192
class BaseSparseEmbedder(LRUCacheMixin, BaseEstimator, ABC):
    """Abstract base class for sparse embedding models using scikit-learn.

    This class provides a unified interface for:

    * Training sparse embedding models (Count, BM25, TF-IDF)
    * Handling custom tokenization or pre-tokenized inputs
    * Converting scipy sparse matrices to zvec-compatible dictionaries
    * Saving and loading trained models
    * LRU caching for repeated embeddings (via LRUCacheMixin)

    The class supports two mutually exclusive modes:

    1. **Pre-tokenized mode** (``is_pretokenized=True``): Input documents are
       already tokenized as lists of strings.
    2. **Custom tokenizer mode** (``tokenizer=<callable>``): A user-provided
       function tokenizes each string document before vectorization.

    If neither is specified, raw strings are passed directly to the underlying
    scikit-learn vectorizer.

    Args:
        tokenizer (Optional[Callable]): A callable that takes a string and
            returns a list of tokens.
        is_pretokenized (bool): If True, input documents must be pre-tokenized
            as lists of strings.
        max_features (Optional[int]): Maximum number of features (non-zero
            elements) to retain per document. Defaults to 8192.
        cache_size (int): Capacity of the LRU embedding cache used for
            repeated single-string inputs. Defaults to 1024.
        preprocessing_config (Optional[NormalizationConfig]): Optional text
            normalization / HF-tokenizer configuration applied before
            tokenization and vectorization.

    Raises:
        ValueError: If both ``tokenizer`` and ``is_pretokenized=True`` are set.

    Example:
        >>> embedder = MyEmbedder(tokenizer=my_tokenize_fn)
        >>> embedder.fit(documents)
        >>> vectors = embedder.embed(["query text"])
    """
def __init__(
    self,
    tokenizer: Optional[Callable] = None,
    is_pretokenized: bool = False,
    max_features: Optional[int] = DEFAULT_MAX_FEATURES,
    cache_size: int = 1024,
    preprocessing_config: Optional["NormalizationConfig"] = None,
):
    """Initialize the embedder configuration.

    Args:
        tokenizer: Callable mapping a string to a list of tokens.
        is_pretokenized: If True, documents are already token lists.
        max_features: Max non-zero features kept per document (None = all).
        cache_size: Capacity of the LRU embedding cache.
        preprocessing_config: Optional normalization / HF-tokenizer config.

    Raises:
        ValueError: If both ``tokenizer`` and ``is_pretokenized=True`` are set.
    """
    if tokenizer is not None and is_pretokenized:
        raise ValueError(
            "Cannot specify both tokenizer and is_pretokenized=True. "
            "Use either a custom tokenizer OR pre-tokenized input, not both."
        )
    self.tokenizer = tokenizer
    self.is_pretokenized = is_pretokenized
    self.max_features = max_features
    self.cache_size = cache_size
    self.preprocessing_config = preprocessing_config
    # NOTE(review): despite the original ``csr_matrix`` annotation, ``model``
    # holds the fitted vectorizer/estimator — it is used via
    # ``self.model.transform(...)`` in ``_compute_embedding`` and is what
    # ``save``/``load`` persist — hence ``BaseEstimator`` here.
    self.model: Optional[BaseEstimator] = None
    # Cache storage for single-string embeddings; presumably consumed by
    # LRUCacheMixin._cached_compute — confirm against the mixin.
    self._embed_cache: dict[str, SparseVector] = {}
    # Guards _embed_cache for thread-safe access.
    self._cache_lock = threading.Lock()
def _prepare_vectorizer_params(self, params: dict) -> dict:
"""Configure scikit-learn vectorizer parameters for the current mode.
When using custom tokenization or pre-tokenized inputs, this method
disables the vectorizer's built-in tokenization to prevent double
processing.
Args:
params (dict): Original vectorizer parameters from the subclass.
Returns:
dict: Updated parameters with tokenization settings adjusted.
"""
cv_params = params.copy()
# Check if HF tokenizer is configured via preprocessing_config
hf_tokenizer_enabled = (
self.preprocessing_config is not None
and self.preprocessing_config.tokenizer is not None
)
if self.is_pretokenized or self.tokenizer or hf_tokenizer_enabled:
cv_params.update(
{
"tokenizer": _identity,
"preprocessor": _identity,
"token_pattern": None,
"lowercase": False,
}
)
return cv_params
def _apply_preprocessing(self, text: str) -> str | list:
"""Apply preprocessing configuration to a text.
Args:
text: Raw text to preprocess.
Returns:
Preprocessed text (str) or list of tokens (list) if HF tokenizer is configured.
If no preprocessing_config is set, returns the original text.
Note:
If the preprocessing module is not available, a warning is logged
and the original text is returned.
"""
if self.preprocessing_config is None:
return text
try:
from zvec_db.preprocessing.config import normalize_text
return normalize_text(text, self.preprocessing_config)
except ImportError:
import warnings
warnings.warn(
"Preprocessing module not available. Install with: "
"pip install 'zvec-db[preprocessing]' or pip install nltk",
ImportWarning,
stacklevel=2,
)
return text
def preprocess(self, text: str) -> str | list:
"""Apply preprocessing to a text (public API).
This method applies the preprocessing configuration to a single text.
It is useful for preprocessing queries or documents before embedding.
Args:
text: Raw text to preprocess.
Returns:
Preprocessed text (str) or list of tokens (list) if HF tokenizer is configured.
If no preprocessing_config is set, returns the original text unchanged.
Example:
>>> from zvec_db.embedders import BM25Embedder
>>> from zvec_db.preprocessing import NormalizationConfig
>>> config = NormalizationConfig.aggressive(language="french")
>>> embedder = BM25Embedder(preprocessing_config=config)
>>> embedder.preprocess(" CHAT MANGEAIT ")
'chat mang'
>>> config = NormalizationConfig.with_hf_tokenizer("gbert-base")
>>> embedder = BM25Embedder(preprocessing_config=config)
>>> embedder.preprocess("Le chat mange")
['le', 'chat', 'man', '##ge']
"""
return self._apply_preprocessing(text)
def _prepare_corpus(self, corpus: ExtendedList) -> ExtendedList:
"""Pre-process a corpus according to the current tokenization mode.
This method validates and transforms input data based on the embedder's
configuration:
* **Pre-tokenized mode**: Validates that each document is a list.
* **Custom tokenizer**: Applies the tokenizer to each string document.
* **Preprocessing**: Applies normalization if preprocessing_config is set.
* **Default**: Returns corpus unchanged.
Args:
corpus (ExtendedList): The input corpus to process.
Returns:
ExtendedList: The processed corpus ready for vectorization.
Raises:
ValueError: If document format doesn't match the configured mode.
"""
if self.is_pretokenized:
for doc in corpus:
if not isinstance(doc, list):
raise ValueError(
"With is_pretokenized=True each document must be a list of tokens"
)
return corpus
if self.tokenizer:
processed = []
for doc in corpus:
if not isinstance(doc, str):
raise ValueError(
"When a tokenizer is set documents must be strings"
)
# Apply preprocessing before tokenization
doc_preprocessed = self._apply_preprocessing(doc)
# If preprocessing returns tokens (HF tokenizer), skip custom tokenizer
if isinstance(doc_preprocessed, list):
processed.append(doc_preprocessed)
else:
processed.append(self.tokenizer(doc_preprocessed))
return processed
# No custom tokenizer: apply preprocessing to raw strings
if self.preprocessing_config is not None:
# _apply_preprocessing may return str|list
preprocessed: list[str | list[str]] = [
self._apply_preprocessing(doc) for doc in corpus # type: ignore[arg-type]
]
# If preprocessing returns tokens (HF tokenizer), corpus is already tokenized
return preprocessed # type: ignore[return-value]
return corpus
@abstractmethod
def fit(self, corpus: ExtendedList, y=None):
    """Train the sparse embedding model on a corpus.

    Subclasses must fit their underlying vectorizer and assign it to
    ``self.model`` (``embed`` raises RuntimeError while it is None).

    Args:
        corpus (ExtendedList): Training documents (raw strings or token
            lists depending on configuration).
        y: Ignored; present for scikit-learn API compatibility.

    Returns:
        self: The fitted embedder instance.
    """
    raise NotImplementedError
def _to_zvec_dict(self, matrix: csr_matrix) -> list[SparseVector]:
"""Convert a scipy CSR matrix to zvec-compatible sparse dictionaries.
Each row of the sparse matrix is converted to a dictionary mapping
feature indices to their values. If ``max_features`` is set, only the
top-k features (by value) per document are retained.
Note:
Keys are automatically sorted in ascending order to work around
a bug in zvec/proxima where unsorted keys cause incorrect search
scores (often returning 0).
Args:
matrix (csr_matrix): Sparse matrix from the fitted model.
Returns:
list[SparseVector]: List of dictionaries, one per document.
Keys in each dictionary are sorted in ascending order.
Raises:
RuntimeError: If the input matrix is None.
"""
if matrix is None:
raise RuntimeError("Input matrix cannot be None")
n_rows = matrix.shape[0]
indptr = matrix.indptr
indices = matrix.indices
data = matrix.data
# Fast path: no max_features limit, CSR indices are already sorted
if self.max_features is None:
return [
{
int(k): float(v)
for k, v in zip(
indices[indptr[i] : indptr[i + 1]],
data[indptr[i] : indptr[i + 1]],
)
}
for i in range(n_rows)
]
# With max_features: select top-k by value, then sort by index
results: list[SparseVector] = []
k = self.max_features
for i in range(n_rows):
start, end = indptr[i], indptr[i + 1]
row_indices = indices[start:end]
row_data = data[start:end]
n_features = len(row_indices)
if n_features > k:
# argpartition is O(n) vs argsort O(n log n) for top-k selection
top_k_idx = np.argpartition(row_data, -k)[-k:]
# Extract top-k indices and values
top_indices = row_indices[top_k_idx]
top_data = row_data[top_k_idx]
# Sort by index (required for zvec/proxima correctness)
sort_idx = top_indices.argsort()
results.append(
{
int(k): float(v)
for k, v in zip(top_indices[sort_idx], top_data[sort_idx])
}
)
else:
# CSR format already stores indices in ascending order
results.append(
{int(k): float(v) for k, v in zip(row_indices, row_data)}
)
return results
def __call__(
self, input_text: StrExtendedList
) -> SparseVector | list[SparseVector]:
"""Call shortcut that delegates to :meth:`embed`.
This allows the embedder to be called like a function::
embedder = BM25Embedder()
embedder.fit(documents)
vector = embedder("query text") # equivalent to embedder.embed(...)
Args:
input_text: Single document or batch of documents.
Returns:
Sparse vector(s) as dictionaries.
"""
return self.embed(input_text)
def embed(self, input_text: StrExtendedList) -> SparseVector | list[SparseVector]:
    """Embed text into zvec-compatible sparse dictionaries.

    The primary user-facing embedding entry point. Unlike :meth:`transform`,
    which yields a scipy sparse matrix, this returns
    ``{feature_index: value}`` dictionaries: a single dict for a single
    document, a list of dicts for a batch.

    .. note::
        The model must be fitted (via :meth:`fit` / :meth:`fit_transform`)
        or loaded before calling this method.

    Args:
        input_text (StrExtendedList): Single document or batch of documents.

    Returns:
        SparseVector | list[SparseVector]: ``dict[int, float]`` for a single
        document; ``list[dict[int, float]]`` for a batch.

    Raises:
        RuntimeError: If the model has not been fitted or loaded.

    Example:
        >>> embedder = BM25Embedder().fit(documents)
        >>> vector = embedder.embed("search query")
        >>> vector  # {42: 0.523, 108: 0.312, ...}
    """
    return self._embed_cached(input_text)
def _embed_cached(
self, input_text: StrExtendedList
) -> SparseVector | list[SparseVector]:
"""Internal method with LRU cache for embed()."""
if self.model is None:
raise RuntimeError("Model must be fitted or loaded before embedding.")
is_single, processed_input = self.preprocess_input(input_text)
# Use cache for single string inputs (thread-safe via mixin)
if is_single and isinstance(input_text, str):
return self._cached_compute(
key=input_text,
compute_fn=lambda: self._compute_embedding(processed_input),
)
# Batch: no caching (too many unique inputs)
return self._compute_embedding(processed_input)
def _compute_embedding(
self, processed_input: str | list
) -> SparseVector | list[SparseVector]:
"""Compute embedding without caching."""
# type ignore: model is checked for None in caller (_embed_cached)
sparse_matrix = self.model.transform(processed_input) # type: ignore[union-attr]
dicts = self._to_zvec_dict(sparse_matrix)
return dicts[0] if len(dicts) == 1 else dicts
def embed_batch(
    self,
    documents: list[str],
    batch_size: int = 32,
    show_progress: bool = False,
) -> list[SparseVector]:
    """Embed a large corpus in fixed-size batches, optionally with tqdm.

    Optimized for large corpora: documents are embedded ``batch_size`` at a
    time, with an optional progress bar for long-running jobs.

    Args:
        documents (list[str]): Documents to embed.
        batch_size (int, optional): Documents per batch. Defaults to 32.
        show_progress (bool, optional): Show a tqdm progress bar (silently
            skipped when tqdm is not installed). Defaults to False.

    Returns:
        list[SparseVector]: One sparse vector per document.

    Example:
        >>> embedder = BM25Embedder().fit(corpus)
        >>> vectors = embedder.embed_batch(large_corpus, batch_size=64)

    Note:
        For single documents or small batches prefer :meth:`embed`, which
        caches repeated inputs.
    """
    if not documents:
        return []

    total = len(documents)
    offsets = range(0, total, batch_size)
    progress = None
    if show_progress:
        try:
            from tqdm import tqdm
        except ImportError:
            pass  # tqdm not installed: run without a progress bar
        else:
            progress = tqdm(offsets, desc="Embedding", unit="batch")

    vectors: list[SparseVector] = []
    for offset in (progress if progress is not None else offsets):
        chunk = documents[offset : offset + batch_size]
        embedded = self.embed(chunk)
        # embed() collapses a single-row result to a plain dict.
        if isinstance(embedded, list):
            vectors.extend(embedded)
        else:
            vectors.append(embedded)
        if progress is not None:
            progress.set_postfix(
                {"processed": min(offset + batch_size, total), "total": total}
            )
    return vectors
def save(self, path: str) -> str:
    """Persist the fitted model and its configuration to disk via joblib.

    joblib efficiently serializes the scikit-learn pipeline and all fitted
    parameters alongside the embedder's configuration.

    Args:
        path (str): Destination file path.

    Returns:
        str: The path that was written (same as the input).

    Example:
        >>> embedder.fit(documents)
        >>> embedder.save("models/bm25_model.joblib")
    """
    state = {
        "model": self.model,
        "tokenizer": self.tokenizer,
        "preprocessing_config": self.preprocessing_config,
        "is_pretokenized": self.is_pretokenized,
        "max_features": self.max_features,
        "cache_size": self.cache_size,
    }
    joblib.dump(state, path)
    return path
def save_pretrained(self, path: str) -> str:
    """Alias for :meth:`save` (Hugging Face-style naming convention).

    Args:
        path (str): File path where the model will be saved.

    Returns:
        str: The path where the model was saved.
    """
    return self.save(path)
def load(self, path: str) -> None:
    """Restore model state previously written by :meth:`save`.

    Reloads the fitted model, tokenizer, preprocessing configuration and
    the other settings (``is_pretokenized``, ``max_features``,
    ``cache_size``). The embedding cache is cleared so vectors computed
    with a previously loaded model are never served for the new one.

    Args:
        path (str): File path to the serialized model.

    Returns:
        None

    Example:
        >>> embedder = BM25Embedder()
        >>> embedder.load("models/bm25_model.joblib")
    """
    data = joblib.load(path)
    self.model = data["model"]
    self.tokenizer = data["tokenizer"]
    self.preprocessing_config = data.get("preprocessing_config")
    self.is_pretokenized = data.get("is_pretokenized", False)
    # NOTE(review): .get() defaults max_features to None (= keep all
    # features) for payloads missing the key; save() always writes it,
    # so this only affects foreign/legacy payloads.
    self.max_features = data.get("max_features")
    self.cache_size = data.get("cache_size", 1024)
    # Bug fix: drop embeddings cached under the previous model, otherwise
    # stale vectors could be returned from the LRU cache after a reload.
    with self._cache_lock:
        self._embed_cache.clear()
def from_pretrained(self, path: str) -> None:
    """Alias for :meth:`load` (Hugging Face-style naming convention).

    Args:
        path (str): File path to the serialized model.

    Returns:
        None
    """
    self.load(path)