"""Count-based sparse embedding using term frequencies.
This module implements simple count-based sparse embedding using scikit-learn's
CountVectorizer. It converts text documents into sparse vectors based on raw
term frequencies (count of each token).
Classes
-------
CountEmbedder
Count-based embedder wrapping scikit-learn's CountVectorizer.
Example Usage
-------------
::
from zvec_db.embedders import CountEmbedder
embedder = CountEmbedder(max_features=4096, binary=True)
embedder.fit(documents)
vector = embedder.embed("search query")
"""
from typing import TYPE_CHECKING, Callable, Optional

from sklearn.feature_extraction.text import CountVectorizer

from ..base import BaseSparseEmbedder, ExtendedList

if TYPE_CHECKING:
    from zvec_db.preprocessing.config import NormalizationConfig

# Default max_features: 8192 (2^13) provides good vocabulary coverage
# while maintaining memory efficiency. This matches the base class default.
DEFAULT_MAX_FEATURES = 8192
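
# Rough doctest-style sketch (comments only, not executed at import) of the
# behavior this module wraps: ``CountVectorizer`` builds its vocabulary on
# fit and keeps only the ``max_features`` most frequent terms corpus-wide.
# The toy corpus below is hypothetical.
#
#   >>> vec = CountVectorizer(max_features=2)
#   >>> vec.fit_transform(["bb bb cc aa", "bb cc cc"]).toarray()
#   array([[2, 1],
#          [1, 2]])
#   >>> vec.get_feature_names_out()
#   array(['bb', 'cc'], dtype=object)
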
class CountEmbedder(BaseSparseEmbedder):
    """Count-based sparse embedder wrapping scikit-learn's ``CountVectorizer``.

    This embedder converts text documents into sparse vectors based on term
    frequencies (raw counts of each token). It is the simplest sparse
    embedding method and serves as a foundation for more advanced techniques
    such as BM25 and TF-IDF.

    The embedder accepts raw strings or pre-tokenized input. Any keyword
    arguments are forwarded to the underlying ``CountVectorizer`` after being
    normalized by :meth:`BaseSparseEmbedder._prepare_vectorizer_params`.

    Args:
        tokenizer (Optional[Callable]): Custom tokenizer function. If
            provided, it is called on each document before vectorization.
        is_pretokenized (bool): If True, input documents must already be
            lists of tokens. Mutually exclusive with ``tokenizer``.
        max_features (Optional[int]): Maximum vocabulary size. Only the
            ``max_features`` most frequent terms across the corpus are
            retained. Defaults to 8192.
        preprocessing_config (Optional[NormalizationConfig]): Configuration
            for automatic text preprocessing (normalization, stemming,
            stopword removal). If set, preprocessing is applied automatically
            during ``fit()`` and ``embed()``.
        **count_params: Additional keyword arguments passed to
            ``CountVectorizer`` (e.g., ``min_df``, ``max_df``,
            ``ngram_range``).

    Example:
        >>> embedder = CountEmbedder(min_df=2, ngram_range=(1, 2))
        >>> embedder.fit(documents)
        >>> vectors = embedder.embed(["query text"])
    """
    def __init__(
        self,
        tokenizer: Optional[Callable] = None,
        is_pretokenized: bool = False,
        max_features: Optional[int] = DEFAULT_MAX_FEATURES,
        preprocessing_config: Optional["NormalizationConfig"] = None,
        **count_params,
    ):
        super().__init__(
            tokenizer,
            is_pretokenized,
            max_features,
            preprocessing_config=preprocessing_config,
        )
        self.vectorizer_params = count_params
    def fit(self, corpus: ExtendedList, y=None):
        """Train the embedder on a corpus of documents.

        The supplied ``corpus`` is normalized according to the instance
        configuration:

        * ``is_pretokenized=True`` - the caller must provide lists of tokens.
        * ``tokenizer=...`` - each string in the corpus is passed through
          the tokenizer before vectorization.
        * neither set - raw strings are passed to ``CountVectorizer``
          directly.

        ``_prepare_corpus`` handles the validation and transformation logic.

        Args:
            corpus: Sequence of documents (strings or token lists, depending
                on configuration).
            y: Ignored. Present for scikit-learn estimator API compatibility.

        Returns:
            ``self`` to allow chaining.
        """
        # Normalize the corpus (tokenization / preprocessing) before fitting.
        processed = self._prepare_corpus(corpus)
        # Merge instance-level settings into the user-supplied params.
        params = self._prepare_vectorizer_params(self.vectorizer_params)
        self.model = CountVectorizer(**params)
        self.model.fit(processed)
        self.is_fitted_ = True
        return self
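

# ---------------------------------------------------------------------------
# Illustrative usage: a minimal sketch of the three input modes described in
# ``fit``. It assumes ``embed`` behaves as shown in the docstrings above; the
# sample documents and the ``str.split`` tokenizer are hypothetical choices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    documents = [
        "sparse vectors from raw counts",
        "counts of each token form the vector",
    ]

    # Mode 1: raw strings handed directly to CountVectorizer.
    embedder = CountEmbedder(max_features=4096)
    embedder.fit(documents)
    print(embedder.embed("sparse counts"))

    # Mode 2: a custom tokenizer applied to each document before
    # vectorization.
    embedder_tok = CountEmbedder(tokenizer=str.split)
    embedder_tok.fit(documents)

    # Mode 3: pre-tokenized input (each document is already a list of
    # tokens). Mutually exclusive with ``tokenizer``.
    embedder_pre = CountEmbedder(is_pretokenized=True)
    embedder_pre.fit([doc.split() for doc in documents])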