"""OpenAI-compatible API embeddings using /embeddings endpoint.
This module provides dense embedding generation using OpenAI-compatible APIs,
which works with:
- OpenAI API (text-embedding-3-small, text-embedding-3-large, etc.)
- vLLM serving open-source embedding models
- Any OpenAI-compatible API endpoint
Available Classes
-----------------
OpenAIEmbedder
Uses the /v1/embeddings endpoint for dense vector generation.
Supports query/passage prefixes for asymmetric embedding models.
Example Usage
-------------
.. code-block:: python
from zvec_db.embedders.dense import OpenAIEmbedder
# OpenAI API
embedder = OpenAIEmbedder(
model="text-embedding-3-small",
api_key="sk-..."
)
vector = embedder.embed("search query")
# vLLM local with asymmetric model (e.g., E5, GTE)
embedder = OpenAIEmbedder(
base_url="http://localhost:8000/v1",
model="intfloat/e5-large-v2",
query_prefix="query: ",
passage_prefix="passage: "
)
query_vector = embedder.embed_query("What is machine learning?")
doc_vector = embedder.embed_passage("ML is a subset of AI.")
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Mapping, Optional, Union
import httpx
import numpy as np
from ...utils.retry import RetryConfig, retry_with_backoff
from ..dense.embedders import BaseDenseEmbedder
logger = logging.getLogger(__name__)
[docs]
class OpenAIEmbedder(BaseDenseEmbedder):
    """Dense embedder using an OpenAI-compatible ``/embeddings`` endpoint.

    Computes dense vector representations of texts via the ``/v1/embeddings``
    endpoint. Compatible with OpenAI's embedding API format and supports
    batch processing.

    Works with:
        - OpenAI API (text-embedding-3-small, text-embedding-3-large, etc.)
        - vLLM serving open-source embedding models
        - Any OpenAI-compatible API endpoint

    Args:
        model (str): Model name to use.
            OpenAI: "text-embedding-3-small", "text-embedding-3-large".
            vLLM: the model name configured in vLLM.
        base_url (str, optional): API base URL.
            For OpenAI: "https://api.openai.com/v1".
            For vLLM local: "http://localhost:8000/v1".
            Defaults to "https://api.openai.com/v1".
        api_key (Optional[str], optional): API key for authentication.
            Defaults to None (reads from OPENAI_API_KEY env var).
        dimensions (Optional[int], optional): Output embedding dimensions.
            Only supported by some models (e.g., text-embedding-3-small).
            Defaults to None (use model default).
        timeout (float, optional): HTTP request timeout in seconds.
            Defaults to 30.0.
        encoding_format (str, optional): Encoding format for embeddings.
            "float" for float32 vectors, "base64" for base64-encoded.
            Defaults to "float".
        max_batch_size (Optional[int], optional): Maximum number of texts to
            embed in a single batch. None means no limit. Defaults to None.
        truncate_prompt_tokens (Optional[int], optional): Maximum number of
            tokens for prompt truncation. When set, prompts exceeding this
            limit are truncated. By default, APIs reject prompts exceeding
            max_model_len unless this is set. Defaults to None (no truncation).
        query_prefix (str, optional): Prefix added to query texts. Useful for
            asymmetric embedding models like E5, GTE, etc. Example: "query: "
            for E5 models. Defaults to "" (no prefix).
        passage_prefix (str, optional): Prefix added to passage/document
            texts. Useful for asymmetric embedding models like E5, GTE, etc.
            Example: "passage: " for E5 models. Defaults to "" (no prefix).
        model_kwargs (Optional[Mapping[str, Any]], optional): Additional
            keyword arguments merged into the API request payload. Useful
            for options like:
            - user: Unique identifier for monitoring and abuse detection
            - extra_headers: Additional HTTP headers
            - extra_query_params: Additional query parameters
            Defaults to None (no additional kwargs).
        model_name (str, optional): Deprecated; use ``model`` instead. Kept
            for backward compatibility. Defaults to None.
        max_retries (int, optional): Maximum number of retry attempts for
            transient failures. Set to 0 to disable retries. Defaults to 3.
        initial_delay (float, optional): Initial delay before the first
            retry, in seconds. Defaults to 1.0.
        max_delay (float, optional): Maximum delay cap in seconds.
            Defaults to 60.0.
        exponential_base (float, optional): Base for exponential backoff.
            Defaults to 2.0.
        jitter (float, optional): Random jitter factor (0.0-1.0) to avoid
            thundering herd. Defaults to 0.1.
        retry_config (Optional[RetryConfig], optional): Pre-configured retry
            settings. If provided, overrides the individual retry
            parameters. Defaults to None.

    Example:
        >>> # OpenAI API
        >>> embedder = OpenAIEmbedder(
        ...     model="text-embedding-3-small",
        ...     api_key="sk-..."
        ... )
        >>> vector = embedder.embed("search query")

        >>> # vLLM local
        >>> embedder = OpenAIEmbedder(
        ...     base_url="http://localhost:8000/v1",
        ...     api_key="not-needed",
        ...     model="BAAI/bge-m3"
        ... )
        >>> vector = embedder.embed("search query")

        >>> # With truncation to handle long prompts
        >>> embedder = OpenAIEmbedder(
        ...     base_url="http://localhost:8000/v1",
        ...     model="embedding",
        ...     truncate_prompt_tokens=512
        ... )

        >>> # With prefixes for asymmetric models (e.g., E5, GTE)
        >>> embedder = OpenAIEmbedder(
        ...     base_url="http://localhost:8000/v1",
        ...     model="intfloat/e5-large-v2",
        ...     query_prefix="query: ",
        ...     passage_prefix="passage: "
        ... )
        >>> query_vector = embedder.embed_query("What is machine learning?")
        >>> doc_vector = embedder.embed_passage("ML is a subset of AI.")

        >>> # With custom retry settings for production
        >>> embedder = OpenAIEmbedder(
        ...     model="text-embedding-3-small",
        ...     max_retries=5,
        ...     initial_delay=2.0,
        ...     max_delay=120.0,
        ... )

    See Also:
        SentenceTransformersEmbedder: Local dense embeddings using
            HuggingFace models.
        RetryConfig: Configuration class for retry behavior.
    """
[docs]
def __init__(
    self,
    model: str = "text-embedding-3-small",
    base_url: str = "https://api.openai.com/v1",
    api_key: Optional[str] = None,
    dimensions: Optional[int] = None,
    timeout: float = 30.0,
    encoding_format: str = "float",
    max_batch_size: Optional[int] = None,
    truncate_prompt_tokens: Optional[int] = None,
    query_prefix: Optional[str] = None,
    passage_prefix: Optional[str] = None,
    model_kwargs: Optional[Mapping[str, Any]] = None,
    # Deprecated: use model instead (OpenAI API naming)
    model_name: Optional[str] = None,
    # Retry configuration
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: float = 0.1,
    retry_config: Optional[RetryConfig] = None,
):
    """Initialize the embedder; see the class docstring for argument details.

    Warns:
        DeprecationWarning: when the deprecated ``model_name`` argument is
            used instead of ``model``.
    """
    # Backward compatibility: 'model_name' used to override 'model'
    # silently. Keep the override but warn so callers can migrate.
    if model_name is not None:
        import warnings

        warnings.warn(
            "'model_name' is deprecated; use 'model' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        model = model_name
    # Prefer an explicit RetryConfig; otherwise assemble one from the
    # individual retry parameters.
    if retry_config is not None:
        self._retry_config = retry_config
    else:
        self._retry_config = RetryConfig(
            max_retries=max_retries,
            initial_delay=initial_delay,
            max_delay=max_delay,
            exponential_base=exponential_base,
            jitter=jitter,
        )
    # Store model under the OpenAI-style name and hand it to the base class.
    self._model = model
    super().__init__(model_name=model)
    self._base_url = base_url.rstrip("/")  # normalize: no trailing slash
    self._api_key = api_key
    self._dimensions = dimensions
    self._timeout = timeout
    self._encoding_format = encoding_format
    self._max_batch_size = max_batch_size
    self._truncate_prompt_tokens = truncate_prompt_tokens
    self._query_prefix = query_prefix or ""
    self._passage_prefix = passage_prefix or ""
    self._model_kwargs: Mapping[str, Any] = model_kwargs or {}
    self._embedding_dim: Optional[int] = None  # learned from first response
    self._is_fitted = False
# NOTE(review): enumerates the public attribute names of this embedder —
# presumably consumed by BaseDenseEmbedder (e.g., for repr/serialization);
# confirm against the base class, as no consumer is visible in this file.
_public_names = (
    "base_url",
    "api_key",
    "model",
    "dimensions",
    "timeout",
    "encoding_format",
    "max_batch_size",
    "truncate_prompt_tokens",
    "query_prefix",
    "passage_prefix",
    "model_kwargs",
)
@property
def model_name(self) -> str:
    """str: Model identifier; legacy alias for :attr:`model`."""
    return self._model

@model_name.setter
def model_name(self, value: str) -> None:
    """Update the model identifier through the legacy alias."""
    self._model = value
@property
def model(self) -> str:
    """str: Model identifier, following OpenAI API naming."""
    return self._model

@model.setter
def model(self, value: str) -> None:
    """Replace the model identifier."""
    self._model = value
@property
def base_url(self) -> str:
    """str: API base URL (stored without a trailing slash)."""
    return self._base_url
@property
def api_key(self) -> Optional[str]:
    """Optional[str]: API key sent as a Bearer token, or None."""
    return self._api_key
@property
def dimensions(self) -> Optional[int]:
    """Optional[int]: Requested output dimensionality, or None for default."""
    return self._dimensions
@property
def timeout(self) -> float:
    """float: HTTP request timeout, in seconds."""
    return self._timeout
@property
def encoding_format(self) -> str:
    """str: Embedding encoding requested from the API ("float" or "base64")."""
    return self._encoding_format
@property
def max_batch_size(self) -> Optional[int]:
    """Optional[int]: Per-request batch cap; None means unbatched."""
    return self._max_batch_size
@property
def truncate_prompt_tokens(self) -> Optional[int]:
    """Optional[int]: Token limit for server-side prompt truncation, or None."""
    return self._truncate_prompt_tokens
@property
def query_prefix(self) -> str:
    """str: Prefix prepended to query texts ("" when disabled)."""
    return self._query_prefix
@property
def passage_prefix(self) -> str:
    """str: Prefix prepended to passage/document texts ("" when disabled)."""
    return self._passage_prefix
@property
def model_kwargs(self) -> Mapping[str, Any]:
    """Mapping[str, Any]: Extra key/value pairs merged into the API payload."""
    return self._model_kwargs
@property
def embedding_dim(self) -> int:
    """int: Dimension of the embedding vectors.

    Uses the dimension observed from the first API response when available;
    otherwise falls back to well-known per-model defaults (1536 if the
    model is unrecognized).
    """
    dim = self._embedding_dim
    if dim is None:
        dim = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
            "text-embedding-ada-002": 1536,
            "embedding": 1024,  # default vLLM model alias
        }.get(self.model_name, 1536)
    return dim
@property
def is_fitted(self) -> bool:
    """bool: True once :meth:`fit` has been called."""
    return self._is_fitted
# NOTE(review): these retry parameters are hard-coded at decoration time and
# therefore ignore ``self._retry_config`` assembled in ``__init__`` — the
# user-supplied retry settings never take effect here. Confirm whether
# retry_with_backoff can consume the instance config instead.
@retry_with_backoff(
    max_retries=3,
    initial_delay=1.0,
    max_delay=60.0,
    exponential_base=2.0,
    jitter=0.1,
    retry_on_timeout=True,
)
def _call_embeddings_api(
    self,
    texts: Union[str, List[str]],
    prefix: Optional[str] = None,
) -> List[List[float]]:
    """Call the /v1/embeddings endpoint.

    Args:
        texts (Union[str, List[str]]): Single text or list of texts to embed.
        prefix (Optional[str], optional): Prefix to add to each text.
            Defaults to None (no prefix).

    Returns:
        List[List[float]]: One embedding vector per input text, in input
        order. NOTE(review): with ``encoding_format="base64"`` the API
        returns base64 strings, which are not decoded here — confirm
        float-only usage.

    Raises:
        RuntimeError: If the API call fails after retries.
        httpx.TimeoutException: On timeout (if retries exhausted).
        httpx.ConnectError: On connection failure (if retries exhausted).
        httpx.HTTPStatusError: On HTTP error with retryable status (if
            retries exhausted).
    """
    # Normalize to a list so the request body is uniform.
    if isinstance(texts, str):
        texts = [texts]
    # Apply prefix if provided (asymmetric models like E5/GTE).
    if prefix:
        texts = [f"{prefix}{text}" for text in texts]
    headers = {
        "Content-Type": "application/json",
    }
    # Bearer auth is optional (e.g., local vLLM needs none).
    if self._api_key:
        headers["Authorization"] = f"Bearer {self._api_key}"
    payload: Dict[str, Any] = {
        "model": self.model_name,
        "input": texts,
        "encoding_format": self._encoding_format,
    }
    if self._dimensions is not None:
        payload["dimensions"] = self._dimensions
    if self._truncate_prompt_tokens is not None:
        payload["truncate_prompt_tokens"] = self._truncate_prompt_tokens
    # model_kwargs are merged last, so they can override any key above —
    # NOTE(review): confirm that overriding "model"/"input" is intended.
    payload.update(self._model_kwargs)
    response = httpx.post(
        f"{self._base_url}/embeddings",
        headers=headers,
        json=payload,
        timeout=self._timeout,
    )
    response.raise_for_status()
    data = response.json()
    # Extract embeddings from response
    # Response format: {"data": [{"index": 0, "embedding": [...]}, ...]}
    results = data.get("data", [])
    # Sort by index so vectors line up with the input order even if the
    # server returns them out of order.
    sorted_results = sorted(results, key=lambda x: x.get("index", 0))
    embeddings = [r.get("embedding", []) for r in sorted_results]
    # Cache the observed dimension for the embedding_dim property.
    if embeddings and self._embedding_dim is None:
        self._embedding_dim = len(embeddings[0])
    return embeddings
[docs]
def fit(self, documents: List[str]) -> "OpenAIEmbedder":
    """Mark the embedder as ready.

    API-backed models are pre-trained, so no training happens here; the
    method only flips the fitted flag and exists for interface parity.

    Args:
        documents: Ignored; accepted for API compatibility.

    Returns:
        OpenAIEmbedder: self, to allow method chaining.
    """
    self._is_fitted = True
    return self
[docs]
def embed(
    self,
    input_text: Union[str, List[str]],
    prefix: Optional[str] = None,
) -> Union[np.ndarray, List[np.ndarray]]:
    """Embed text(s) into dense float32 vectors via the API.

    Args:
        input_text (Union[str, List[str]]): A single string or a list of
            strings to embed.
        prefix (Optional[str], optional): Prefix prepended to every text
            before embedding. Defaults to None (no prefix).

    Returns:
        Union[np.ndarray, List[np.ndarray]]: A single array of shape
        (embedding_dim,) for string input, otherwise one array per input
        text.
    """
    single_input = isinstance(input_text, str)
    raw_vectors = self._call_embeddings_api(input_text, prefix=prefix)
    vectors = [np.asarray(vec, dtype=np.float32) for vec in raw_vectors]
    return vectors[0] if single_input else vectors
[docs]
def embed_query(
    self,
    query: Union[str, List[str]],
) -> Union[np.ndarray, List[np.ndarray]]:
    """Embed query text(s), applying the configured query prefix.

    Args:
        query (Union[str, List[str]]): A single query or a list of queries.

    Returns:
        Union[np.ndarray, List[np.ndarray]]: One array of shape
        (embedding_dim,) for a single query, or a list of such arrays.
    """
    # An empty prefix is treated the same as no prefix at all.
    active_prefix = self._query_prefix or None
    return self.embed(query, prefix=active_prefix)
[docs]
def embed_passage(
    self,
    passage: Union[str, List[str]],
) -> Union[np.ndarray, List[np.ndarray]]:
    """Embed passage/document text(s), applying the configured passage prefix.

    Args:
        passage (Union[str, List[str]]): A single passage or a list of
            passages.

    Returns:
        Union[np.ndarray, List[np.ndarray]]: One array of shape
        (embedding_dim,) for a single passage, or a list of such arrays.
    """
    # An empty prefix is treated the same as no prefix at all.
    active_prefix = self._passage_prefix or None
    return self.embed(passage, prefix=active_prefix)
[docs]
def embed_batch(
    self,
    documents: List[str],
    show_progress: bool = False,
    prefix: Optional[str] = None,
) -> List[np.ndarray]:
    """Embed a batch of documents, honoring ``max_batch_size``.

    Args:
        documents (List[str]): Documents to embed.
        show_progress (bool, optional): Ignored for API-based embedding;
            accepted for interface parity. Defaults to False.
        prefix (Optional[str], optional): Prefix prepended to each document.
            Defaults to None (no prefix).

    Returns:
        List[np.ndarray]: One embedding vector per document.
    """
    size = self._max_batch_size
    if not size or size <= 0:
        # No batch cap configured: one request for the whole corpus.
        return self.embed(documents, prefix=prefix)  # type: ignore
    vectors: List[np.ndarray] = []
    for start in range(0, len(documents), size):
        chunk_vecs = self.embed(documents[start : start + size], prefix=prefix)
        # A single-string input would yield a bare array; normalize to list.
        if isinstance(chunk_vecs, np.ndarray):
            chunk_vecs = [chunk_vecs]
        vectors.extend(chunk_vecs)
    return vectors