Source code for zvec_db.embedders.sparse.count

"""Count-based sparse embedding using term frequencies.

This module implements simple count-based sparse embedding using scikit-learn's
CountVectorizer. It converts text documents into sparse vectors based on raw
term frequencies (count of each token).

Classes
-------
CountEmbedder
    Count-based embedder wrapping scikit-learn's CountVectorizer.

Example Usage
-------------
::

    from zvec_db.embedders import CountEmbedder

    embedder = CountEmbedder(max_features=4096, binary=True)
    embedder.fit(documents)
    vector = embedder.embed("search query")
"""

from typing import TYPE_CHECKING, Callable, Optional

from sklearn.feature_extraction.text import CountVectorizer

from ..base import BaseSparseEmbedder, ExtendedList

if TYPE_CHECKING:
    from zvec_db.preprocessing.config import NormalizationConfig

# Default max_features: 8192 (2^13) provides good vocabulary coverage
# while maintaining memory efficiency. This matches the base class default.
DEFAULT_MAX_FEATURES = 8192
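
# A minimal sketch (illustration only, not used by this module) of what the
# cap means in scikit-learn terms: ``CountVectorizer`` keeps only the
# ``max_features`` terms with the highest total frequency across the corpus.
#
#     >>> from sklearn.feature_extraction.text import CountVectorizer
#     >>> docs = ["apple apple banana", "apple banana cherry"]
#     >>> cv = CountVectorizer(max_features=2).fit(docs)
#     >>> sorted(cv.vocabulary_)  # "cherry" is dropped
#     ['apple', 'banana']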


class CountEmbedder(BaseSparseEmbedder):
    """Count-based sparse embedder wrapping scikit-learn's ``CountVectorizer``.

    This embedder converts text documents into sparse vectors based on term
    frequencies (raw counts of each token). It is the simplest sparse
    embedding method and serves as a foundation for more advanced techniques
    such as BM25 and TF-IDF.

    The embedder accepts raw strings or pre-tokenized input. Any keyword
    arguments are forwarded to the underlying ``CountVectorizer`` after being
    normalized by :meth:`BaseSparseEmbedder._prepare_vectorizer_params`.

    Args:
        tokenizer (Optional[Callable]): Custom tokenizer function. If
            provided, it is called on each document before vectorization.
        is_pretokenized (bool): If True, input documents must already be
            lists of tokens. Mutually exclusive with ``tokenizer``.
        max_features (Optional[int]): Maximum vocabulary size; only the most
            frequent terms across the corpus are retained. Defaults to 8192.
        preprocessing_config (Optional[NormalizationConfig]): Configuration
            for automatic text preprocessing (normalization, stemming,
            stopwords). If set, preprocessing is applied automatically
            during ``fit()`` and ``embed()``.
        **count_params: Additional keyword arguments passed to
            ``CountVectorizer`` (e.g., ``min_df``, ``max_df``,
            ``ngram_range``).

    Example:
        >>> embedder = CountEmbedder(min_df=2, ngram_range=(1, 2))
        >>> embedder.fit(documents)
        >>> vectors = embedder.embed(["query text"])
    """

    def __init__(
        self,
        tokenizer: Optional[Callable] = None,
        is_pretokenized: bool = False,
        max_features: Optional[int] = DEFAULT_MAX_FEATURES,
        preprocessing_config: Optional["NormalizationConfig"] = None,
        **count_params,
    ):
        super().__init__(
            tokenizer,
            is_pretokenized,
            max_features,
            preprocessing_config=preprocessing_config,
        )
        self.vectorizer_params = count_params
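
    # A construction sketch for the supported input modes; ``simple_tokenize``
    # is a hypothetical user-supplied callable, not part of this package:
    #
    #     >>> plain = CountEmbedder()                       # raw strings
    #     >>> custom = CountEmbedder(tokenizer=simple_tokenize)
    #     >>> tokens = CountEmbedder(is_pretokenized=True)  # token lists
    #     >>> binary = CountEmbedder(binary=True)           # presence/absence
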
    def fit(self, corpus: ExtendedList, y=None):
        """Train the embedder on a corpus of documents.

        The supplied ``corpus`` is normalised according to the instance
        configuration:

        * ``is_pretokenized=True`` - the caller must provide lists of tokens.
        * ``tokenizer=...`` - each string in the corpus is passed through the
          tokenizer before vectorisation.
        * neither set - raw strings are passed to ``CountVectorizer``
          directly.

        ``_prepare_corpus`` handles the validation and transformation logic.

        Args:
            corpus: Sequence of documents (strings or token lists, depending
                on configuration).
            y: Ignored; accepted for scikit-learn API compatibility.

        Returns:
            ``self``, to allow chaining.
        """
        # Normalise the corpus before fitting.
        processed = self._prepare_corpus(corpus)
        params = self._prepare_vectorizer_params(self.vectorizer_params)
        self.model = CountVectorizer(**params)
        self.model.fit(processed)
        self.is_fitted_ = True
        return self
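
    # A usage sketch for the three input modes above (``documents`` and
    # ``token_lists`` are placeholder variables):
    #
    #     >>> CountEmbedder().fit(documents)                     # raw strings
    #     >>> CountEmbedder(tokenizer=str.split).fit(documents)  # tokenizer
    #     >>> CountEmbedder(is_pretokenized=True).fit(token_lists)
    #
    # Because ``fit`` returns ``self``, fitting and embedding can be chained:
    #
    #     >>> vector = CountEmbedder().fit(documents).embed("search query")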