# Source code for zvec_db.embedders.dense.openai

"""OpenAI-compatible API embeddings using /embeddings endpoint.

This module provides dense embedding generation using OpenAI-compatible APIs,
which works with:
- OpenAI API (text-embedding-3-small, text-embedding-3-large, etc.)
- vLLM serving open-source embedding models
- Any OpenAI-compatible API endpoint

Available Classes
-----------------
OpenAIEmbedder
    Uses the /v1/embeddings endpoint for dense vector generation.
    Supports query/passage prefixes for asymmetric embedding models.

Example Usage
-------------
.. code-block:: python

    from zvec_db.embedders.dense import OpenAIEmbedder

    # OpenAI API
    embedder = OpenAIEmbedder(
        model_name="text-embedding-3-small",
        api_key="sk-..."
    )
    vector = embedder.embed("search query")

    # vLLM local with asymmetric model (e.g., E5, GTE)
    embedder = OpenAIEmbedder(
        base_url="http://localhost:8000/v1",
        model_name="intfloat/e5-large-v2",
        query_prefix="query: ",
        passage_prefix="passage: "
    )
    query_vector = embedder.embed_query("What is machine learning?")
    doc_vector = embedder.embed_passage("ML is a subset of AI.")
"""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Mapping, Optional, Union

import httpx
import numpy as np

from ...utils.retry import RetryConfig, retry_with_backoff
from ..dense.embedders import BaseDenseEmbedder

logger = logging.getLogger(__name__)


class OpenAIEmbedder(BaseDenseEmbedder):
    """Dense embedder using an OpenAI-compatible ``/embeddings`` endpoint.

    Computes dense vector representations of texts via the ``/v1/embeddings``
    HTTP API. Compatible with the OpenAI embedding API request/response
    format, so it works with:

    - OpenAI API (text-embedding-3-small, text-embedding-3-large, etc.)
    - vLLM serving open-source embedding models
    - Any OpenAI-compatible API endpoint

    Args:
        model (str): Model name to use.
            OpenAI: "text-embedding-3-small", "text-embedding-3-large".
            vLLM: the model name configured in vLLM.
        base_url (str, optional): API base URL.
            For OpenAI: "https://api.openai.com/v1".
            For vLLM local: "http://localhost:8000/v1".
            Defaults to "https://api.openai.com/v1".
        api_key (Optional[str], optional): API key for authentication. When
            None, no Authorization header is sent. Defaults to None.
        dimensions (Optional[int], optional): Output embedding dimensions.
            Only supported by some models (e.g., text-embedding-3-small).
            Defaults to None (use model default).
        timeout (float, optional): HTTP request timeout in seconds.
            Defaults to 30.0.
        encoding_format (str, optional): Encoding format for embeddings.
            "float" for float32 vectors, "base64" for base64-encoded.
            Defaults to "float".
        max_batch_size (Optional[int], optional): Maximum number of texts
            embedded in a single API request; ``embed_batch`` splits larger
            inputs. None means no limit. Defaults to None.
        truncate_prompt_tokens (Optional[int], optional): Maximum number of
            tokens for prompt truncation. When set, prompts exceeding this
            limit are truncated server-side; by default, APIs reject prompts
            exceeding max_model_len unless this is set. Defaults to None.
        query_prefix (str, optional): Prefix added to query texts, for
            asymmetric embedding models like E5 or GTE (e.g. "query: ").
            Defaults to "" (no prefix).
        passage_prefix (str, optional): Prefix added to passage/document
            texts (e.g. "passage: " for E5 models). Defaults to "".
        model_kwargs (Optional[Mapping[str, Any]], optional): Additional
            key/value pairs merged into the request payload (e.g. ``user``
            for abuse monitoring). Defaults to None.
        model_name (str, optional): Deprecated alias for ``model``, kept for
            backward compatibility. Defaults to None.
        max_retries (int, optional): Maximum retry attempts for transient
            failures; 0 disables retries. Defaults to 3.
        initial_delay (float, optional): Delay before the first retry, in
            seconds. Defaults to 1.0.
        max_delay (float, optional): Maximum backoff delay cap in seconds.
            Defaults to 60.0.
        exponential_base (float, optional): Base for exponential backoff.
            Defaults to 2.0.
        jitter (float, optional): Random jitter factor (0.0-1.0) to avoid
            thundering herd. Defaults to 0.1.
        retry_config (Optional[RetryConfig], optional): Pre-configured retry
            settings; if provided, overrides the individual retry
            parameters. Defaults to None.

    Example:
        >>> # OpenAI API
        >>> embedder = OpenAIEmbedder(
        ...     model="text-embedding-3-small",
        ...     api_key="sk-...",
        ... )
        >>> vector = embedder.embed("search query")

        >>> # vLLM local
        >>> embedder = OpenAIEmbedder(
        ...     base_url="http://localhost:8000/v1",
        ...     api_key="not-needed",
        ...     model="BAAI/bge-m3",
        ... )
        >>> vector = embedder.embed("search query")

        >>> # With truncation to handle long prompts
        >>> embedder = OpenAIEmbedder(
        ...     base_url="http://localhost:8000/v1",
        ...     model="embedding",
        ...     truncate_prompt_tokens=512,
        ... )

        >>> # With prefixes for asymmetric models (e.g., E5, GTE)
        >>> embedder = OpenAIEmbedder(
        ...     base_url="http://localhost:8000/v1",
        ...     model="intfloat/e5-large-v2",
        ...     query_prefix="query: ",
        ...     passage_prefix="passage: ",
        ... )
        >>> query_vector = embedder.embed_query("What is machine learning?")
        >>> doc_vector = embedder.embed_passage("ML is a subset of AI.")

        >>> # With custom retry settings for production
        >>> embedder = OpenAIEmbedder(
        ...     model="text-embedding-3-small",
        ...     max_retries=5,
        ...     initial_delay=2.0,
        ...     max_delay=120.0,
        ... )

    See Also:
        SentenceTransformersEmbedder: Local dense embeddings using
            HuggingFace models.
        RetryConfig: Configuration class for retry behavior.
    """

    # Configuration attributes exposed via same-named public properties.
    _public_names = (
        "base_url",
        "api_key",
        "model",
        "dimensions",
        "timeout",
        "encoding_format",
        "max_batch_size",
        "truncate_prompt_tokens",
        "query_prefix",
        "passage_prefix",
        "model_kwargs",
    )

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        base_url: str = "https://api.openai.com/v1",
        api_key: Optional[str] = None,
        dimensions: Optional[int] = None,
        timeout: float = 30.0,
        encoding_format: str = "float",
        max_batch_size: Optional[int] = None,
        truncate_prompt_tokens: Optional[int] = None,
        query_prefix: Optional[str] = None,
        passage_prefix: Optional[str] = None,
        model_kwargs: Optional[Mapping[str, Any]] = None,
        # Deprecated: use model instead (OpenAI API naming)
        model_name: Optional[str] = None,
        # Retry configuration
        max_retries: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: float = 0.1,
        retry_config: Optional[RetryConfig] = None,
    ):
        # Support deprecated 'model_name' parameter for backward compatibility.
        if model_name is not None:
            model = model_name

        # Prefer an explicit RetryConfig; otherwise assemble one from the
        # individual retry parameters.
        if retry_config is not None:
            self._retry_config = retry_config
        else:
            self._retry_config = RetryConfig(
                max_retries=max_retries,
                initial_delay=initial_delay,
                max_delay=max_delay,
                exponential_base=exponential_base,
                jitter=jitter,
            )

        # Store model as a private attribute (OpenAI API naming) and also
        # pass it to the parent class under its 'model_name' convention.
        self._model = model
        super().__init__(model_name=model)

        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._dimensions = dimensions
        self._timeout = timeout
        self._encoding_format = encoding_format
        self._max_batch_size = max_batch_size
        self._truncate_prompt_tokens = truncate_prompt_tokens
        self._query_prefix = query_prefix or ""
        self._passage_prefix = passage_prefix or ""
        self._model_kwargs: Mapping[str, Any] = model_kwargs or {}
        self._embedding_dim: Optional[int] = None
        self._is_fitted = False

    @property
    def model_name(self) -> str:
        """str: Model identifier (alias for model for backward compatibility)."""
        return self._model

    @model_name.setter
    def model_name(self, value: str) -> None:
        """Setter for model_name (backward compatibility)."""
        self._model = value

    @property
    def model(self) -> str:
        """str: Model identifier (OpenAI API naming)."""
        return self._model

    @model.setter
    def model(self, value: str) -> None:
        """Setter for model."""
        self._model = value

    @property
    def base_url(self) -> str:
        """str: Base URL for the API."""
        return self._base_url

    @property
    def api_key(self) -> Optional[str]:
        """Optional[str]: API key for authentication."""
        return self._api_key

    @property
    def dimensions(self) -> Optional[int]:
        """Optional[int]: Output embedding dimensions."""
        return self._dimensions

    @property
    def timeout(self) -> float:
        """float: HTTP request timeout in seconds."""
        return self._timeout

    @property
    def encoding_format(self) -> str:
        """str: Encoding format for embeddings."""
        return self._encoding_format

    @property
    def max_batch_size(self) -> Optional[int]:
        """Optional[int]: Maximum batch size for embedding."""
        return self._max_batch_size

    @property
    def truncate_prompt_tokens(self) -> Optional[int]:
        """Optional[int]: Maximum number of tokens for prompt truncation."""
        return self._truncate_prompt_tokens

    @property
    def query_prefix(self) -> str:
        """str: Prefix added to query texts."""
        return self._query_prefix

    @property
    def passage_prefix(self) -> str:
        """str: Prefix added to passage/document texts."""
        return self._passage_prefix

    @property
    def model_kwargs(self) -> Mapping[str, Any]:
        """Mapping[str, Any]: Additional kwargs passed to the API."""
        return self._model_kwargs

    @property
    def embedding_dim(self) -> int:
        """int: Dimension of embeddings (available after fit or first embed)."""
        if self._embedding_dim is not None:
            return self._embedding_dim
        # Fall back to known default dimensions for common models; the real
        # value is cached from the first API response.
        defaults = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
            "text-embedding-ada-002": 1536,
            "embedding": 1024,  # Default vLLM model
        }
        return defaults.get(self.model_name, 1536)

    @property
    def is_fitted(self) -> bool:
        """bool: Whether the embedder has been fitted."""
        return self._is_fitted

    def _call_embeddings_api(
        self,
        texts: Union[str, List[str]],
        prefix: Optional[str] = None,
    ) -> List[List[float]]:
        """Call the /v1/embeddings endpoint with the configured retry policy.

        Args:
            texts (Union[str, List[str]]): Single text or list of texts.
            prefix (Optional[str], optional): Prefix to add to each text.
                Defaults to None (no prefix).

        Returns:
            List[List[float]]: One embedding vector per input text, in
            input order.

        Raises:
            httpx.TimeoutException: On timeout (if retries exhausted).
            httpx.ConnectError: On connection failure (if retries exhausted).
            httpx.HTTPStatusError: On HTTP error status (if retries
                exhausted).
        """
        # Bug fix: the previous implementation decorated this method with a
        # hard-coded @retry_with_backoff(max_retries=3, initial_delay=1.0,
        # ...), which silently ignored the user-supplied retry settings
        # (max_retries/initial_delay/.../retry_config from __init__). Build
        # the retry wrapper from the instance configuration instead.
        # NOTE(review): assumes RetryConfig stores its constructor kwargs as
        # same-named attributes — confirm against utils.retry.
        cfg = self._retry_config
        retried = retry_with_backoff(
            max_retries=cfg.max_retries,
            initial_delay=cfg.initial_delay,
            max_delay=cfg.max_delay,
            exponential_base=cfg.exponential_base,
            jitter=cfg.jitter,
            retry_on_timeout=True,
        )(self._request_embeddings)
        return retried(texts, prefix)

    def _request_embeddings(
        self,
        texts: Union[str, List[str]],
        prefix: Optional[str] = None,
    ) -> List[List[float]]:
        """Perform a single, non-retried POST to the /embeddings endpoint."""
        # Normalize to a list of texts.
        if isinstance(texts, str):
            texts = [texts]

        # Apply prefix if provided (empty/None prefix is a no-op).
        if prefix:
            texts = [f"{prefix}{text}" for text in texts]

        headers = {
            "Content-Type": "application/json",
        }
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        payload: Dict[str, Any] = {
            "model": self.model_name,
            "input": texts,
            "encoding_format": self._encoding_format,
        }
        if self._dimensions is not None:
            payload["dimensions"] = self._dimensions
        if self._truncate_prompt_tokens is not None:
            payload["truncate_prompt_tokens"] = self._truncate_prompt_tokens

        # Caller-supplied extras may add to (or deliberately override)
        # request fields.
        payload.update(self._model_kwargs)

        response = httpx.post(
            f"{self._base_url}/embeddings",
            headers=headers,
            json=payload,
            timeout=self._timeout,
        )
        response.raise_for_status()
        data = response.json()

        # Response format: {"data": [{"index": 0, "embedding": [...]}, ...]}.
        # Sort by index so the output order matches the input order.
        results = sorted(
            data.get("data", []), key=lambda item: item.get("index", 0)
        )
        embeddings = [item.get("embedding", []) for item in results]

        # Cache the embedding dimension from the first observed vector.
        if embeddings and self._embedding_dim is None:
            self._embedding_dim = len(embeddings[0])

        return embeddings

    def fit(self, documents: List[str]) -> "OpenAIEmbedder":
        """Initialize the embedder.

        For an API-based embedder this is a no-op, since the model is
        pre-trained; the method exists for API compatibility.

        Args:
            documents: List of documents (unused, for API compatibility).

        Returns:
            self: For method chaining.
        """
        self._is_fitted = True
        return self

    def embed(
        self,
        input_text: Union[str, List[str]],
        prefix: Optional[str] = None,
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Embed texts into dense vectors.

        Args:
            input_text (Union[str, List[str]]): Single text or list of texts.
            prefix (Optional[str], optional): Prefix to add to each text.
                Defaults to None (no prefix).

        Returns:
            Union[np.ndarray, List[np.ndarray]]:
                - If single text: np.ndarray of shape (embedding_dim,)
                - If multiple texts: List[np.ndarray], one per input text
        """
        is_single = isinstance(input_text, str)

        # Robustness: avoid a pointless API round-trip for an empty batch.
        if not is_single and not input_text:
            return []

        embeddings = self._call_embeddings_api(input_text, prefix=prefix)
        np_embeddings = [np.array(emb, dtype=np.float32) for emb in embeddings]
        if is_single:
            return np_embeddings[0]
        return np_embeddings

    def embed_query(
        self,
        query: Union[str, List[str]],
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Embed a query or list of queries with the query prefix.

        Args:
            query (Union[str, List[str]]): Single query or list of queries.

        Returns:
            Union[np.ndarray, List[np.ndarray]]:
                - If single query: np.ndarray of shape (embedding_dim,)
                - If multiple queries: List[np.ndarray], one per query
        """
        return self.embed(
            query, prefix=self._query_prefix if self._query_prefix else None
        )

    def embed_passage(
        self,
        passage: Union[str, List[str]],
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Embed a passage/document or list of passages with the passage prefix.

        Args:
            passage (Union[str, List[str]]): Single passage or list of
                passages to embed.

        Returns:
            Union[np.ndarray, List[np.ndarray]]:
                - If single passage: np.ndarray of shape (embedding_dim,)
                - If multiple passages: List[np.ndarray], one per passage
        """
        return self.embed(
            passage,
            prefix=self._passage_prefix if self._passage_prefix else None,
        )

    def embed_batch(
        self,
        documents: List[str],
        show_progress: bool = False,
        prefix: Optional[str] = None,
    ) -> List[np.ndarray]:
        """Embed a batch of documents, chunking by max_batch_size if set.

        Args:
            documents (List[str]): List of documents to embed.
            show_progress (bool, optional): Show progress bar. Not used for
                API-based embedding. Defaults to False.
            prefix (Optional[str], optional): Prefix to add to each document.
                Defaults to None (no prefix).

        Returns:
            List[np.ndarray]: List of embedding vectors.
        """
        if self._max_batch_size and self._max_batch_size > 0:
            # Process in chunks so each API request stays under the limit.
            all_embeddings: List[np.ndarray] = []
            for start in range(0, len(documents), self._max_batch_size):
                batch = documents[start : start + self._max_batch_size]
                batch_embeddings = self.embed(batch, prefix=prefix)
                # Defensive: embed() returns a bare array for single-string
                # input; normalize to a list before extending.
                if isinstance(batch_embeddings, np.ndarray):
                    batch_embeddings = [batch_embeddings]
                all_embeddings.extend(batch_embeddings)
            return all_embeddings
        else:
            # No batch limit: embed everything in one request.
            return self.embed(documents, prefix=prefix)  # type: ignore