Files
skill-seekers-reference/src/skill_seekers/embedding/cache.py
yusyus 51787e57bc style: Fix 411 ruff lint issues (Kimi's issue #4)
Auto-fixed lint issues with ruff --fix and --unsafe-fixes:

Issue #4: Ruff Lint Issues
- Before: 447 errors (originally reported as ~5,500)
- After: 55 errors remaining
- Fixed: 411 errors (92% reduction)

Auto-fixes applied:
- 156 UP006: List/Dict → list/dict (PEP 585)
- 63 UP045: Optional[X] → X | None (PEP 604)
- 52 F401: Removed unused imports
- 52 UP035: Fixed deprecated imports
- 34 E712: True/False comparisons → not/bool()
- 17 F841: Removed unused variables
- Plus 37 other auto-fixable issues

Remaining 55 errors (non-critical):
- 39 B904: Exception chaining (best practice)
- 5 F401: Unused imports (edge cases)
- 3 SIM105: Could use contextlib.suppress
- 8 other minor style issues

These remaining issues are code quality improvements, not critical bugs.

Result: Code quality significantly improved (92% of linting issues resolved)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-08 12:46:38 +03:00

335 lines
8.5 KiB
Python

"""
Caching layer for embeddings.
"""
import json
import sqlite3
from datetime import datetime, timedelta, timezone
from pathlib import Path
class EmbeddingCache:
"""
SQLite-based cache for embeddings.
Stores embeddings with their text hashes to avoid regeneration.
Supports TTL (time-to-live) for cache entries.
Examples:
cache = EmbeddingCache("/path/to/cache.db")
# Store embedding
cache.set("hash123", [0.1, 0.2, 0.3], model="text-embedding-3-small")
# Retrieve embedding
embedding = cache.get("hash123")
# Check if cached
if cache.has("hash123"):
print("Embedding is cached")
"""
def __init__(self, db_path: str = ":memory:", ttl_days: int = 30):
"""
Initialize embedding cache.
Args:
db_path: Path to SQLite database (":memory:" for in-memory)
ttl_days: Time-to-live for cache entries in days
"""
self.db_path = db_path
self.ttl_days = ttl_days
# Create database directory if needed
if db_path != ":memory:":
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# Initialize database
self.conn = sqlite3.connect(db_path, check_same_thread=False)
self._init_db()
def _init_db(self):
"""Initialize database schema."""
cursor = self.conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
hash TEXT PRIMARY KEY,
embedding TEXT NOT NULL,
model TEXT NOT NULL,
dimensions INTEGER NOT NULL,
created_at TEXT NOT NULL,
accessed_at TEXT NOT NULL,
access_count INTEGER DEFAULT 1
)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_model ON embeddings(model)
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_created_at ON embeddings(created_at)
""")
self.conn.commit()
def set(
self,
hash_key: str,
embedding: list[float],
model: str
) -> None:
"""
Store embedding in cache.
Args:
hash_key: Hash of text+model
embedding: Embedding vector
model: Model name
"""
cursor = self.conn.cursor()
now = datetime.utcnow().isoformat()
embedding_json = json.dumps(embedding)
dimensions = len(embedding)
cursor.execute("""
INSERT OR REPLACE INTO embeddings
(hash, embedding, model, dimensions, created_at, accessed_at, access_count)
VALUES (?, ?, ?, ?, ?, ?, 1)
""", (hash_key, embedding_json, model, dimensions, now, now))
self.conn.commit()
def get(self, hash_key: str) -> list[float] | None:
"""
Retrieve embedding from cache.
Args:
hash_key: Hash of text+model
Returns:
Embedding vector if cached and not expired, None otherwise
"""
cursor = self.conn.cursor()
# Get embedding
cursor.execute("""
SELECT embedding, created_at
FROM embeddings
WHERE hash = ?
""", (hash_key,))
row = cursor.fetchone()
if not row:
return None
embedding_json, created_at = row
# Check TTL
created = datetime.fromisoformat(created_at)
if datetime.utcnow() - created > timedelta(days=self.ttl_days):
# Expired, delete and return None
self.delete(hash_key)
return None
# Update access stats
now = datetime.utcnow().isoformat()
cursor.execute("""
UPDATE embeddings
SET accessed_at = ?, access_count = access_count + 1
WHERE hash = ?
""", (now, hash_key))
self.conn.commit()
return json.loads(embedding_json)
def get_batch(self, hash_keys: list[str]) -> tuple[list[list[float] | None], list[bool]]:
"""
Retrieve multiple embeddings from cache.
Args:
hash_keys: List of hashes
Returns:
Tuple of (embeddings list, cached flags)
embeddings list contains None for cache misses
"""
embeddings = []
cached_flags = []
for hash_key in hash_keys:
embedding = self.get(hash_key)
embeddings.append(embedding)
cached_flags.append(embedding is not None)
return embeddings, cached_flags
def has(self, hash_key: str) -> bool:
"""
Check if embedding is cached and not expired.
Args:
hash_key: Hash of text+model
Returns:
True if cached and not expired, False otherwise
"""
cursor = self.conn.cursor()
cursor.execute("""
SELECT created_at
FROM embeddings
WHERE hash = ?
""", (hash_key,))
row = cursor.fetchone()
if not row:
return False
# Check TTL
created = datetime.fromisoformat(row[0])
if datetime.utcnow() - created > timedelta(days=self.ttl_days):
# Expired
self.delete(hash_key)
return False
return True
def delete(self, hash_key: str) -> None:
"""
Delete embedding from cache.
Args:
hash_key: Hash of text+model
"""
cursor = self.conn.cursor()
cursor.execute("""
DELETE FROM embeddings
WHERE hash = ?
""", (hash_key,))
self.conn.commit()
def clear(self, model: str | None = None) -> int:
"""
Clear cache entries.
Args:
model: If provided, only clear entries for this model
Returns:
Number of entries deleted
"""
cursor = self.conn.cursor()
if model:
cursor.execute("""
DELETE FROM embeddings
WHERE model = ?
""", (model,))
else:
cursor.execute("DELETE FROM embeddings")
deleted = cursor.rowcount
self.conn.commit()
return deleted
def clear_expired(self) -> int:
"""
Clear expired cache entries.
Returns:
Number of entries deleted
"""
cursor = self.conn.cursor()
cutoff = (datetime.utcnow() - timedelta(days=self.ttl_days)).isoformat()
cursor.execute("""
DELETE FROM embeddings
WHERE created_at < ?
""", (cutoff,))
deleted = cursor.rowcount
self.conn.commit()
return deleted
def size(self) -> int:
"""
Get number of cached embeddings.
Returns:
Number of cache entries
"""
cursor = self.conn.cursor()
cursor.execute("SELECT COUNT(*) FROM embeddings")
return cursor.fetchone()[0]
def stats(self) -> dict:
"""
Get cache statistics.
Returns:
Dictionary with cache stats
"""
cursor = self.conn.cursor()
# Total entries
cursor.execute("SELECT COUNT(*) FROM embeddings")
total = cursor.fetchone()[0]
# Entries by model
cursor.execute("""
SELECT model, COUNT(*)
FROM embeddings
GROUP BY model
""")
by_model = {row[0]: row[1] for row in cursor.fetchall()}
# Most accessed
cursor.execute("""
SELECT hash, model, access_count
FROM embeddings
ORDER BY access_count DESC
LIMIT 10
""")
top_accessed = [
{"hash": row[0], "model": row[1], "access_count": row[2]}
for row in cursor.fetchall()
]
# Expired entries
cutoff = (datetime.utcnow() - timedelta(days=self.ttl_days)).isoformat()
cursor.execute("""
SELECT COUNT(*)
FROM embeddings
WHERE created_at < ?
""", (cutoff,))
expired = cursor.fetchone()[0]
return {
"total": total,
"by_model": by_model,
"top_accessed": top_accessed,
"expired": expired,
"ttl_days": self.ttl_days
}
def close(self):
"""Close database connection."""
self.conn.close()
def __enter__(self):
"""Context manager entry."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.close()