# Source code for zvec_db.embedders.dense.sentence_transformers
"""Sentence Transformers embeddings using local models.
This module provides dense embedding generation using Sentence Transformers
models from HuggingFace. These models run locally on CPU or GPU.
Available Classes
-----------------
SentenceTransformersEmbedder
Uses local Sentence Transformers models for dense vector generation.
Supports hundreds of pre-trained models from HuggingFace.
Example Usage
-------------
.. code-block:: python
from zvec_db.embedders.dense import SentenceTransformersEmbedder
# Standard embedding
embedder = SentenceTransformersEmbedder(
model_name="all-MiniLM-L6-v2",
device="cpu"
)
embedder.fit(documents)
vector = embedder.embed("search query")
# With model_kwargs for private models or custom options
embedder = SentenceTransformersEmbedder(
model_name="org/private-model",
model_kwargs={"token": "hf_...", "trust_remote_code": True}
)
# With float16 for reduced memory
import torch
embedder = SentenceTransformersEmbedder(
model_name="all-MiniLM-L6-v2",
model_kwargs={"torch_dtype": torch.float16}
)
"""
from __future__ import annotations
import logging
from typing import Any, List, Mapping, Optional
import numpy as np
from ..dense.embedders import BaseDenseEmbedder
logger = logging.getLogger(__name__)
class SentenceTransformersEmbedder(BaseDenseEmbedder):
    """Dense embeddings using Sentence Transformers models locally.

    This embedder uses pre-trained models from the sentence-transformers
    library to generate semantic embeddings. It supports hundreds of
    models available on HuggingFace.

    Args:
        model_name (str, optional): Name of the model from HuggingFace.
            Examples:

            - "all-MiniLM-L6-v2" (384 dims, fast)
            - "all-mpnet-base-v2" (768 dims, best quality)
            - "BAAI/bge-small-en-v1.5" (384 dims, good quality)

            Defaults to "all-MiniLM-L6-v2".
        device (Optional[str], optional): Device to run model on.
            "cpu", "cuda", or None for auto-detect. Defaults to None.
        max_length (Optional[int], optional): Maximum sequence length.
            Defaults to 512.
        normalize (bool, optional): Normalize embeddings to unit length.
            Defaults to True for cosine similarity compatibility.
        trust_remote_code (bool, optional): Trust remote code in model.
            Defaults to False.
        model_kwargs (Optional[Mapping[str, Any]], optional): Additional keyword arguments
            passed to the SentenceTransformer constructor. Useful for options like:

            - torch_dtype: Model dtype (torch.float16, torch.bfloat16, "auto")
            - trust_remote_code: Trust remote code from HuggingFace Hub
            - token: HuggingFace API token for private models
            - revision: Model revision to load
            - cache_dir: Custom cache directory
            - local_files_only: Load only local files
            - attn_implementation: Attention implementation (e.g., "flash_attention_2")

            Defaults to None (no additional kwargs).

    Example:
        >>> # Standard embedding
        >>> embedder = SentenceTransformersEmbedder(
        ...     model_name="all-MiniLM-L6-v2",
        ...     device="cpu"
        ... )
        >>> embedder.fit(["document 1", "document 2"])
        >>> vector = embedder.embed("search query")
        >>> print(vector.shape)
        (384,)

        >>> # With model_kwargs for private models
        >>> embedder = SentenceTransformersEmbedder(
        ...     model_name="org/private-model",
        ...     model_kwargs={"token": "hf_..."}
        ... )

        >>> # With float16 for reduced memory
        >>> import torch
        >>> embedder = SentenceTransformersEmbedder(
        ...     model_name="all-MiniLM-L6-v2",
        ...     model_kwargs={"torch_dtype": torch.float16}
        ... )

    Note:
        - Requires the `sentence-transformers` package
        - Models are downloaded automatically on first use
        - GPU acceleration available if CUDA is installed

    See Also:
        OpenAIEmbedder: Dense embeddings via OpenAI-compatible API.
    """

    # NOTE(review): looks like a registry of this class's public property
    # names, presumably consumed by BaseDenseEmbedder or package tooling —
    # confirm against the base class before relying on it.
    _public_names = ("device", "trust_remote_code", "model_kwargs")

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        device: Optional[str] = None,
        max_length: Optional[int] = 512,
        normalize: bool = True,
        trust_remote_code: bool = False,
        model_kwargs: Optional[Mapping[str, Any]] = None,
    ):
        super().__init__(
            model_name=model_name, max_length=max_length, normalize=normalize
        )
        self._device = device
        self._trust_remote_code = trust_remote_code
        # The mapping is treated as read-only; no defensive copy is taken.
        self._model_kwargs: Mapping[str, Any] = model_kwargs or {}
        # Lazily created SentenceTransformer instance (see _load_model).
        self._model: Optional[Any] = None

    @property
    def device(self) -> Optional[str]:
        """Optional[str]: Device to run model on."""
        return self._device

    @property
    def trust_remote_code(self) -> bool:
        """bool: Trust remote code in model."""
        return self._trust_remote_code

    @property
    def model_kwargs(self) -> Mapping[str, Any]:
        """Mapping[str, Any]: Additional kwargs passed to the model."""
        return self._model_kwargs

    @property
    def embedding_dim(self) -> int:
        """int: Dimension of the embedding vectors.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if self._model is None:
            raise RuntimeError("Model not loaded. Call fit() first.")
        dim = self._model.get_sentence_embedding_dimension()
        # Some models report None; fall back to 0 rather than crash callers.
        return dim if dim is not None else 0

    @property
    def is_fitted(self) -> bool:
        """bool: Whether the embedder has been fitted."""
        return self._model is not None

    def _load_model(self) -> None:
        """Load the Sentence Transformers model.

        Raises:
            ImportError: If the sentence-transformers package is not installed.
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise ImportError(
                "SentenceTransformersEmbedder requires the 'sentence-transformers' package. "
                "Install it with: pip install sentence-transformers"
            ) from e
        self._model = SentenceTransformer(
            self.model_name,
            device=self._device,
            trust_remote_code=self._trust_remote_code,
            **self._model_kwargs,
        )

    def _ensure_model(self) -> None:
        """Lazily load the model, raising if it is still unavailable.

        Raises:
            ImportError: If sentence-transformers is not installed.
            RuntimeError: If the model could not be loaded.
        """
        if self._model is None:
            self._load_model()
        if self._model is None:
            raise RuntimeError("Model not loaded. Call fit() first.")

    def _encode(self, texts: str | List[str]) -> np.ndarray:
        """Encode text(s) with the shared model settings.

        Returns a 1-D array for a single string, 2-D for a list.
        """
        return self._model.encode(  # type: ignore[union-attr]
            texts,
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
            show_progress_bar=False,
        )

    def fit(self, documents: List[str]) -> "SentenceTransformersEmbedder":
        """Initialize the embedder by loading the model.

        For Sentence Transformers, this loads the model.
        No training is performed as models are pre-trained.

        Args:
            documents: List of documents (used for initialization only).

        Returns:
            self: For method chaining.
        """
        if self._model is None:
            self._load_model()
        return self

    def embed(
        self,
        input_text: str | List[str],
    ) -> np.ndarray | list[np.ndarray]:
        """Generate embeddings for text.

        Args:
            input_text: Single document or batch.

        Returns:
            Single numpy array for a string, or a list of arrays for a batch.

        Raises:
            RuntimeError: If model loading fails.
        """
        # Auto-load model if not already loaded (lazy loading).
        self._ensure_model()
        if isinstance(input_text, str):
            return self._encode(input_text)
        # Batch: split the 2-D result into one array per document.
        return list(self._encode(input_text))

    def embed_batch(
        self,
        documents: List[str],
        batch_size: int = 32,
        show_progress: bool = False,
    ) -> List[np.ndarray]:
        """Embed a large batch of documents with optional progress bar.

        This method is optimized for processing large corpora by embedding
        documents in smaller batches. It supports an optional progress bar
        for tracking long-running operations.

        Args:
            documents (List[str]): List of documents to embed.
            batch_size (int, optional): Number of documents per batch.
                Defaults to 32.
            show_progress (bool, optional): Show progress bar. Defaults to False.

        Returns:
            List[np.ndarray]: List of embedding arrays, one per document.

        Example:
            >>> embedder = SentenceTransformersEmbedder().fit(corpus)
            >>> vectors = embedder.embed_batch(
            ...     large_corpus,
            ...     batch_size=64,
            ...     show_progress=True
            ... )

        Note:
            For single documents or small batches, use :meth:`embed` instead.
        """
        if not documents:
            return []
        # Auto-load model if not already loaded.
        self._ensure_model()
        total = len(documents)
        starts = range(0, total, batch_size)
        progress = None
        if show_progress:
            try:
                from tqdm import tqdm
            except ImportError:
                # tqdm not installed; embed without a progress bar.
                pass
            else:
                progress = tqdm(starts, desc="Embedding", unit="batch")
        results: List[np.ndarray] = []
        for i in (progress if progress is not None else starts):
            batch = documents[i : i + batch_size]
            results.extend(self._encode(batch))
            if progress is not None:
                progress.set_postfix(
                    {"processed": min(i + batch_size, total), "total": total}
                )
        return results