"""LRU cache mixin for embedders.
This module provides a reusable mixin class for LRU caching with thread-safe
access, used by both sparse and dense embedders.
Available Classes
-----------------
LRUCacheMixin
Mixin class providing thread-safe LRU cache for embedding computations.
Example Usage
-------------
.. code-block:: python
from zvec_db.utils.cache import LRUCacheMixin
class MyEmbedder(LRUCacheMixin):
def __init__(self, cache_size=1024):
self.cache_size = cache_size
self._embed_cache = {}
self._cache_lock = self._init_cache_lock()
def compute(self, text: str) -> list[float]:
# Use cached computation if available
return self._cached_compute(
key=text,
compute_fn=lambda: self._do_compute(text),
)
def _do_compute(self, text: str) -> list[float]:
# Actual computation logic
return [0.1, 0.2, 0.3]
"""
from __future__ import annotations

import threading
from typing import Any, Callable, TypeVar

T = TypeVar("T")
class LRUCacheMixin:
    """Mixin class providing a thread-safe LRU cache for embedding computations.

    This mixin encapsulates the common caching pattern used by both sparse
    and dense embedders: thread-safe cache access with LRU eviction when
    the cache reaches its maximum size.

    To use this mixin, your class must have:

    - `cache_size`: Maximum number of items to cache
    - `_embed_cache`: Dictionary for cache storage
    - `_cache_lock`: Threading lock for thread-safe access

    Attributes:
        cache_size (int): Maximum number of items to keep in the cache.
        _embed_cache (dict): Internal cache storage.
        _cache_lock (threading.Lock): Lock for thread-safe cache access.

    Example:
        >>> class MyEmbedder(LRUCacheMixin):
        ...     def __init__(self):
        ...         self.cache_size = 1024
        ...         self._embed_cache = {}
        ...         self._cache_lock = threading.Lock()
        ...
        ...     def embed(self, text: str) -> list[float]:
        ...         return self._cached_compute(
        ...             key=text,
        ...             compute_fn=lambda: self._compute(text),
        ...         )
    """

    cache_size: int
    _embed_cache: dict[str, Any]
    _cache_lock: threading.Lock
    def _init_cache_lock(self) -> threading.Lock:
        """Initialize a threading lock for cache access.

        Returns:
            threading.Lock: A new lock instance.
        """
        return threading.Lock()
    def _cached_compute(
        self,
        key: str,
        compute_fn: Callable[[], T],
    ) -> T:
        """Compute with LRU caching (thread-safe).

        This method checks whether a result is already cached for the given
        key. On a hit, the cached result is returned and the entry is marked
        as most recently used. On a miss, `compute_fn` is called to compute
        the result, which is then cached and returned.

        When the cache is full, the least recently used entry is evicted to
        make room for new ones.

        Args:
            key: Cache key (typically the input text).
            compute_fn: Function to call on cache miss to compute the result.

        Returns:
            Computed result (cached or fresh).

        Note:
            Cache access is thread-safe via a lock. `compute_fn` is called
            outside the lock to avoid holding it during long computations,
            so two threads that miss on the same key concurrently may both
            compute it; the result written last wins, which is harmless for
            a cache.
        """
        # Check the cache (with lock). On a hit, re-insert the key so it
        # becomes the most recently used entry; dicts preserve insertion
        # order, so the first key is always the least recently used.
        with self._cache_lock:
            if key in self._embed_cache:
                result = self._embed_cache.pop(key)
                self._embed_cache[key] = result
                return result

        # Cache miss: compute outside the lock to avoid holding it during a
        # potentially long computation. Concurrent misses on the same key may
        # each call compute_fn; the last writer wins.
        result = compute_fn()

        # Store in the cache (with lock). Another thread may have stored the
        # key while we computed; in that case we simply overwrite its entry.
        with self._cache_lock:
            if key not in self._embed_cache and len(self._embed_cache) >= self.cache_size:
                # LRU eviction: evict the first (least recently used) key,
                # guarding against a zero-size cache.
                if self._embed_cache:
                    self._embed_cache.pop(next(iter(self._embed_cache)))
            self._embed_cache[key] = result
        return result
    def clear_cache(self) -> None:
        """Clear the embedding cache.

        This method removes all cached entries, freeing memory. Useful when
        you want to force recomputation of all embeddings.

        Note:
            This method is thread-safe.
        """
        with self._cache_lock:
            self._embed_cache.clear()
    def cache_info(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with cache statistics:

            - size: Current number of cached items
            - max_size: Maximum cache capacity
            - utilization: Current utilization as a percentage (0-100)
        """
        with self._cache_lock:
            current_size = len(self._embed_cache)
        utilization = (
            (current_size / self.cache_size * 100) if self.cache_size > 0 else 0
        )
        return {
            "size": current_size,
            "max_size": self.cache_size,
            "utilization": round(utilization, 2),
        }
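

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the library API).
# ToyEmbedder is a hypothetical class that wires up the three required
# attributes and routes compute() through _cached_compute(). Run this module
# directly to see hits, LRU eviction, and cache_info() in action.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ToyEmbedder(LRUCacheMixin):
        def __init__(self, cache_size: int = 2) -> None:
            self.cache_size = cache_size
            self._embed_cache = {}
            self._cache_lock = self._init_cache_lock()

        def compute(self, text: str) -> list[float]:
            return self._cached_compute(
                key=text,
                compute_fn=lambda: self._do_compute(text),
            )

        def _do_compute(self, text: str) -> list[float]:
            # Stand-in for a real embedding model.
            return [float(len(text)), 0.0]

    embedder = ToyEmbedder(cache_size=2)
    embedder.compute("alpha")  # miss -> computed and cached
    embedder.compute("beta")   # miss -> cache now full
    embedder.compute("alpha")  # hit -> "alpha" becomes most recently used
    embedder.compute("gamma")  # miss -> evicts "beta" (least recently used)
    print(embedder.cache_info())  # {'size': 2, 'max_size': 2, 'utilization': 100.0}
    embedder.clear_cache()
    print(embedder.cache_info())  # {'size': 0, 'max_size': 2, 'utilization': 0.0}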