# Source code for zvec_db.utils.cache

"""LRU cache mixin for embedders.

This module provides a reusable mixin class for LRU caching with thread-safe
access, used by both sparse and dense embedders.

Available Classes
-----------------
LRUCacheMixin
    Mixin class providing thread-safe LRU cache for embedding computations.

Example Usage
-------------
.. code-block:: python

    from zvec_db.utils.cache import LRUCacheMixin

    class MyEmbedder(LRUCacheMixin):
        def __init__(self, cache_size=1024):
            self.cache_size = cache_size
            self._embed_cache = {}
            self._cache_lock = self._init_cache_lock()

        def compute(self, text: str) -> list[float]:
            # Use cached computation if available
            return self._cached_compute(
                key=text,
                compute_fn=lambda: self._do_compute(text),
            )

        def _do_compute(self, text: str) -> list[float]:
            # Actual computation logic
            return [0.1, 0.2, 0.3]
"""

from __future__ import annotations

import threading
from typing import Any, Callable, TypeVar

T = TypeVar("T")


class LRUCacheMixin:
    """Mixin class providing a thread-safe LRU cache for embedding computations.

    This mixin encapsulates the common caching pattern used by both sparse
    and dense embedders: thread-safe cache access with true LRU eviction
    when the cache reaches its maximum size.

    The LRU ordering relies on the insertion-order guarantee of ``dict``
    (Python 3.7+): the first key in iteration order is the least recently
    used, and a cache hit refreshes recency by re-inserting the key at the
    end.

    To use this mixin, your class must have:
    - `cache_size`: Maximum number of items to cache
    - `_embed_cache`: Dictionary for cache storage
    - `_cache_lock`: Threading lock for thread-safe access

    Attributes:
        cache_size (int): Maximum number of items to keep in cache.
            A value <= 0 disables caching (every call recomputes).
        _embed_cache (dict): Internal cache storage (insertion order == LRU order).
        _cache_lock (threading.Lock): Lock for thread-safe cache access.

    Example:
        >>> class MyEmbedder(LRUCacheMixin):
        ...     def __init__(self):
        ...         self.cache_size = 1024
        ...         self._embed_cache = {}
        ...         self._cache_lock = threading.Lock()
        ...
        ...     def embed(self, text: str) -> list[float]:
        ...         return self._cached_compute(
        ...             key=text,
        ...             compute_fn=lambda: self._compute(text),
        ...         )
    """

    cache_size: int
    _embed_cache: dict[str, Any]
    _cache_lock: threading.Lock

    def _init_cache_lock(self) -> threading.Lock:
        """Initialize a threading lock for cache access.

        Returns:
            threading.Lock: New lock instance.
        """
        return threading.Lock()

    def _cached_compute(
        self,
        key: str,
        compute_fn: Callable[[], T],
    ) -> T:
        """Compute with LRU caching (thread-safe).

        This method checks if a result is already cached for the given key.
        If yes, returns the cached result (and marks the key as most recently
        used). If not, calls compute_fn to compute the result, caches it,
        and returns it.

        Eviction is true LRU: when the cache is full, the least recently
        used entry is removed to make room for the new one.

        Args:
            key: Cache key (typically the input text).
            compute_fn: Function to call on cache miss to compute the result.

        Returns:
            Computed result (cached or fresh).

        Note:
            Cache access is thread-safe via a lock. The compute_fn is called
            outside the lock to avoid holding it during long computations;
            as a consequence, concurrent misses on the same key may each
            compute once, but only one result is stored.
        """
        # Fast path: cache hit (with lock held).
        with self._cache_lock:
            if key in self._embed_cache:
                # Refresh recency: pop and re-insert so the key moves to the
                # end of the dict's insertion order (most recently used).
                result = self._embed_cache.pop(key)
                self._embed_cache[key] = result
                return result

        # Cache miss: compute outside the lock to avoid blocking other
        # threads during a potentially long computation.
        result = compute_fn()

        # Store the result (with lock held).
        with self._cache_lock:
            # Skip storage entirely when caching is disabled (cache_size <= 0);
            # also avoid a redundant insert/evict if another thread already
            # stored this key while we were computing.
            if self.cache_size > 0 and key not in self._embed_cache:
                # Evict least-recently-used entries until there is room.
                # The first key in iteration order is the LRU entry.
                while len(self._embed_cache) >= self.cache_size:
                    self._embed_cache.pop(next(iter(self._embed_cache)))
                self._embed_cache[key] = result

        return result

    def clear_cache(self) -> None:
        """Clear the embedding cache.

        This method removes all cached entries, freeing memory. Useful when
        you want to force recomputation of all embeddings.

        Note:
            This method is thread-safe.
        """
        with self._cache_lock:
            self._embed_cache.clear()

    def cache_info(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with cache statistics:
            - size: Current number of cached items
            - max_size: Maximum cache capacity
            - utilization: Current utilization as percentage (0-100)
        """
        with self._cache_lock:
            current_size = len(self._embed_cache)
            # Guard against division by zero for disabled (size <= 0) caches.
            utilization = (
                (current_size / self.cache_size * 100) if self.cache_size > 0 else 0
            )
            return {
                "size": current_size,
                "max_size": self.cache_size,
                "utilization": round(utilization, 2),
            }