"""Dense embedding base classes.
This module provides the base class for dense embedding models.
Main Classes
------------
BaseDenseEmbedder
Abstract base class for all dense embedders.
Note:
Concrete implementations are in separate modules:
- openai.py: OpenAIEmbedder (API-based)
- sentence_transformers.py: SentenceTransformersEmbedder (local)
Example Usage
-------------
::
from zvec_db.embedders.dense import SentenceTransformersEmbedder, OpenAIEmbedder
# Sentence Transformers embedding
embedder = SentenceTransformersEmbedder(
model_name="all-MiniLM-L6-v2",
device="cpu"
)
vector = embedder.embed("search query")
# OpenAI embedding
embedder = OpenAIEmbedder(
base_url="http://localhost:8000/v1",
model="BAAI/bge-m3"
)
# fit() not needed for API-based embedders
vector = embedder.embed("search query")
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import numpy as np
class BaseDenseEmbedder(ABC):
    """Base class for dense embedding models.

    Dense embedders generate fixed-size vector representations of text,
    as opposed to sparse embeddings which have variable dimensions.

    Args:
        model_name (str): Name or path of the model to use.
        max_length (Optional[int]): Maximum sequence length. Defaults to 512.
        normalize (bool): Whether to normalize embeddings to unit length.
            Defaults to True for cosine similarity compatibility.

    Example:
        >>> embedder = SentenceTransformersEmbedder("all-MiniLM-L6-v2")
        >>> embedder.fit(["document 1", "document 2"])
        >>> vector = embedder.embed("query")
        >>> len(vector)  # Fixed size
        384
    """

    def __init__(
        self,
        model_name: str,
        max_length: Optional[int] = 512,
        normalize: bool = True,
    ):
        self.model_name = model_name
        self.max_length = max_length
        self.normalize = normalize
        # Flipped to True by concrete fit() implementations once the
        # underlying model is ready; exposed read-only via `is_fitted`.
        self._fitted: bool = False

    @abstractmethod
    def fit(self, documents: List[str]) -> "BaseDenseEmbedder":
        """Initialize the embedder on a corpus.

        For dense models, this is typically optional and just initializes
        the model. Unlike sparse models, no vocabulary is learned.

        Args:
            documents: List of documents for initialization.

        Returns:
            self: For method chaining.
        """
        ...

    @abstractmethod
    def embed(
        self, input_text: Union[str, List[str]]
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Generate dense embeddings for text.

        Args:
            input_text: Single document or batch of documents.

        Returns:
            Numpy array for single input, or list of numpy arrays for
            batch input.
        """
        ...

    def __call__(
        self, input_text: Union[str, List[str]]
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Call shortcut that delegates to :meth:`embed`.

        This allows the embedder to be called like a function::

            embedder = SentenceTransformersEmbedder()
            embedder.fit(documents)
            vector = embedder("query text")  # equivalent to embedder.embed(...)

        Args:
            input_text: Single document or batch of documents.

        Returns:
            Numpy array for single input, or list of numpy arrays for
            batch input (same contract as :meth:`embed`).
        """
        return self.embed(input_text)

    def save(self, path: str) -> None:
        """Save embedder configuration.

        Dense models typically don't need saving as they load
        pre-trained weights. This saves configuration only.

        Args:
            path: Path to save configuration.
        """
        # Local import keeps joblib an optional dependency for users who
        # never persist configuration.
        import joblib

        joblib.dump(
            {
                "model_name": self.model_name,
                "max_length": self.max_length,
                "normalize": self.normalize,
            },
            path,
        )

    def load(self, path: str) -> None:
        """Load embedder configuration.

        Args:
            path: Path to configuration file.
        """
        import joblib

        config = joblib.load(path)
        # .get() with the current value as default tolerates configs
        # written by older versions that lack a key.
        self.model_name = config.get("model_name", self.model_name)
        self.max_length = config.get("max_length", self.max_length)
        self.normalize = config.get("normalize", self.normalize)
        # NOTE(review): _fitted is neither saved nor restored, so callers
        # must call fit() again after load() — confirm this is intended.

    @property
    def is_fitted(self) -> bool:
        """bool: True if embedder is initialized."""
        return self._fitted