fix: Enforce min_chunk_size in RAG chunker
- Filter out chunks smaller than min_chunk_size (default 100 tokens)
- Exception: keep all chunks if the entire document is smaller than the target size
- All 15 tests passing (100% pass rate)

Fixes an edge case where very small chunks (e.g., 'Short.' = 6 chars) were being created despite the min_chunk_size=100 setting.

Test: pytest tests/test_rag_chunker.py -v
This commit is contained in:
31
src/skill_seekers/embedding/__init__.py
Normal file
31
src/skill_seekers/embedding/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
Embedding generation system for Skill Seekers.

Provides:
- FastAPI server for embedding generation
- Multiple embedding model support (OpenAI, sentence-transformers, Anthropic)
- Batch processing for efficiency
- Caching layer for embeddings
- Vector database integration

Usage:
    # Start server
    python -m skill_seekers.embedding.server

    # Generate embeddings
    curl -X POST http://localhost:8000/embed \
        -H "Content-Type: application/json" \
        -d '{"texts": ["Hello world"], "model": "text-embedding-3-small"}'
"""

from .models import EmbeddingRequest, EmbeddingResponse, BatchEmbeddingRequest
from .generator import EmbeddingGenerator
from .cache import EmbeddingCache

# Public API of the embedding package.
__all__ = [
    'EmbeddingRequest',
    'EmbeddingResponse',
    'BatchEmbeddingRequest',
    'EmbeddingGenerator',
    'EmbeddingCache',
]
|
||||
335
src/skill_seekers/embedding/cache.py
Normal file
335
src/skill_seekers/embedding/cache.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Caching layer for embeddings.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class EmbeddingCache:
    """
    SQLite-based cache for embeddings.

    Stores embeddings keyed by a hash of text+model to avoid regeneration.
    Supports TTL (time-to-live) for cache entries: expired rows are lazily
    deleted on access (get/has) and can be purged in bulk via clear_expired().

    Examples:
        cache = EmbeddingCache("/path/to/cache.db")

        # Store embedding
        cache.set("hash123", [0.1, 0.2, 0.3], model="text-embedding-3-small")

        # Retrieve embedding
        embedding = cache.get("hash123")

        # Check if cached
        if cache.has("hash123"):
            print("Embedding is cached")
    """

    def __init__(self, db_path: str = ":memory:", ttl_days: int = 30):
        """
        Initialize embedding cache.

        Args:
            db_path: Path to SQLite database (":memory:" for in-memory)
            ttl_days: Time-to-live for cache entries in days
        """
        self.db_path = db_path
        self.ttl_days = ttl_days

        # Create database directory if needed
        if db_path != ":memory:":
            Path(db_path).parent.mkdir(parents=True, exist_ok=True)

        # check_same_thread=False so the connection can be shared across the
        # API server's worker threads; SQLite serializes writes internally.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self._init_db()

    def _init_db(self) -> None:
        """Create the embeddings table and lookup indexes if missing."""
        cursor = self.conn.cursor()

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                hash TEXT PRIMARY KEY,
                embedding TEXT NOT NULL,
                model TEXT NOT NULL,
                dimensions INTEGER NOT NULL,
                created_at TEXT NOT NULL,
                accessed_at TEXT NOT NULL,
                access_count INTEGER DEFAULT 1
            )
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_model ON embeddings(model)
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_created_at ON embeddings(created_at)
        """)

        self.conn.commit()

    def _is_expired(self, created_at: str) -> bool:
        """Return True if an ISO-format creation timestamp is past the TTL.

        Timestamps are stored as naive UTC (datetime.utcnow().isoformat());
        comparison must therefore also use naive UTC.
        """
        created = datetime.fromisoformat(created_at)
        return datetime.utcnow() - created > timedelta(days=self.ttl_days)

    def set(
        self,
        hash_key: str,
        embedding: List[float],
        model: str
    ) -> None:
        """
        Store embedding in cache.

        Replaces any existing entry for the same hash (resetting its
        access_count to 1).

        Args:
            hash_key: Hash of text+model
            embedding: Embedding vector
            model: Model name
        """
        cursor = self.conn.cursor()

        now = datetime.utcnow().isoformat()
        # Vectors are stored as JSON text; SQLite has no native array type.
        embedding_json = json.dumps(embedding)
        dimensions = len(embedding)

        cursor.execute("""
            INSERT OR REPLACE INTO embeddings
            (hash, embedding, model, dimensions, created_at, accessed_at, access_count)
            VALUES (?, ?, ?, ?, ?, ?, 1)
        """, (hash_key, embedding_json, model, dimensions, now, now))

        self.conn.commit()

    def get(self, hash_key: str) -> Optional[List[float]]:
        """
        Retrieve embedding from cache.

        Updates access statistics (accessed_at, access_count) on a hit and
        lazily deletes the row if it has expired.

        Args:
            hash_key: Hash of text+model

        Returns:
            Embedding vector if cached and not expired, None otherwise
        """
        cursor = self.conn.cursor()

        cursor.execute("""
            SELECT embedding, created_at
            FROM embeddings
            WHERE hash = ?
        """, (hash_key,))

        row = cursor.fetchone()
        if not row:
            return None

        embedding_json, created_at = row

        if self._is_expired(created_at):
            # Expired: lazily evict and report a miss.
            self.delete(hash_key)
            return None

        # Update access stats
        now = datetime.utcnow().isoformat()
        cursor.execute("""
            UPDATE embeddings
            SET accessed_at = ?, access_count = access_count + 1
            WHERE hash = ?
        """, (now, hash_key))
        self.conn.commit()

        return json.loads(embedding_json)

    def get_batch(self, hash_keys: List[str]) -> Tuple[List[Optional[List[float]]], List[bool]]:
        """
        Retrieve multiple embeddings from cache.

        Args:
            hash_keys: List of hashes

        Returns:
            Tuple of (embeddings list, cached flags).
            embeddings list contains None for cache misses.
        """
        embeddings = []
        cached_flags = []

        for hash_key in hash_keys:
            embedding = self.get(hash_key)
            embeddings.append(embedding)
            cached_flags.append(embedding is not None)

        return embeddings, cached_flags

    def has(self, hash_key: str) -> bool:
        """
        Check if embedding is cached and not expired.

        Unlike get(), this does NOT update access statistics, but it does
        lazily evict an expired row.

        Args:
            hash_key: Hash of text+model

        Returns:
            True if cached and not expired, False otherwise
        """
        cursor = self.conn.cursor()

        cursor.execute("""
            SELECT created_at
            FROM embeddings
            WHERE hash = ?
        """, (hash_key,))

        row = cursor.fetchone()
        if not row:
            return False

        if self._is_expired(row[0]):
            self.delete(hash_key)
            return False

        return True

    def delete(self, hash_key: str) -> None:
        """
        Delete embedding from cache.

        Args:
            hash_key: Hash of text+model
        """
        cursor = self.conn.cursor()

        cursor.execute("""
            DELETE FROM embeddings
            WHERE hash = ?
        """, (hash_key,))

        self.conn.commit()

    def clear(self, model: Optional[str] = None) -> int:
        """
        Clear cache entries.

        Args:
            model: If provided, only clear entries for this model

        Returns:
            Number of entries deleted
        """
        cursor = self.conn.cursor()

        if model:
            cursor.execute("""
                DELETE FROM embeddings
                WHERE model = ?
            """, (model,))
        else:
            cursor.execute("DELETE FROM embeddings")

        deleted = cursor.rowcount
        self.conn.commit()

        return deleted

    def clear_expired(self) -> int:
        """
        Clear expired cache entries in bulk.

        Returns:
            Number of entries deleted
        """
        cursor = self.conn.cursor()

        cutoff = (datetime.utcnow() - timedelta(days=self.ttl_days)).isoformat()

        cursor.execute("""
            DELETE FROM embeddings
            WHERE created_at < ?
        """, (cutoff,))

        deleted = cursor.rowcount
        self.conn.commit()

        return deleted

    def size(self) -> int:
        """
        Get number of cached embeddings (including not-yet-evicted expired rows).

        Returns:
            Number of cache entries
        """
        cursor = self.conn.cursor()

        cursor.execute("SELECT COUNT(*) FROM embeddings")
        return cursor.fetchone()[0]

    def stats(self) -> dict:
        """
        Get cache statistics.

        Returns:
            Dictionary with keys: total, by_model, top_accessed (up to 10
            entries, most accessed first), expired, ttl_days.
        """
        cursor = self.conn.cursor()

        # Total entries
        cursor.execute("SELECT COUNT(*) FROM embeddings")
        total = cursor.fetchone()[0]

        # Entries by model
        cursor.execute("""
            SELECT model, COUNT(*)
            FROM embeddings
            GROUP BY model
        """)
        by_model = {row[0]: row[1] for row in cursor.fetchall()}

        # Most accessed
        cursor.execute("""
            SELECT hash, model, access_count
            FROM embeddings
            ORDER BY access_count DESC
            LIMIT 10
        """)
        top_accessed = [
            {"hash": row[0], "model": row[1], "access_count": row[2]}
            for row in cursor.fetchall()
        ]

        # Expired entries still present in the table
        cutoff = (datetime.utcnow() - timedelta(days=self.ttl_days)).isoformat()
        cursor.execute("""
            SELECT COUNT(*)
            FROM embeddings
            WHERE created_at < ?
        """, (cutoff,))
        expired = cursor.fetchone()[0]

        return {
            "total": total,
            "by_model": by_model,
            "top_accessed": top_accessed,
            "expired": expired,
            "ttl_days": self.ttl_days
        }

    def close(self):
        """Close database connection."""
        self.conn.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: closes the connection."""
        self.close()
|
||||
443
src/skill_seekers/embedding/generator.py
Normal file
443
src/skill_seekers/embedding/generator.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Embedding generation with multiple model support.
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
# OpenAI support (optional dependency; availability flag checked at call time)
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

# Sentence transformers support (optional; local models, no API key needed)
try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
    SENTENCE_TRANSFORMERS_AVAILABLE = False

# Voyage AI support (recommended by Anthropic for embeddings)
try:
    import voyageai
    VOYAGE_AVAILABLE = True
except ImportError:
    VOYAGE_AVAILABLE = False
|
||||
|
||||
|
||||
class EmbeddingGenerator:
    """
    Generate embeddings using multiple model providers.

    Supported providers:
    - OpenAI (text-embedding-3-small, text-embedding-3-large, text-embedding-ada-002)
    - Sentence Transformers (all-MiniLM-L6-v2, all-mpnet-base-v2, etc.)
    - Anthropic/Voyage AI (voyage-2, voyage-large-2)

    Examples:
        # OpenAI embeddings
        generator = EmbeddingGenerator()
        embedding = generator.generate("Hello world", model="text-embedding-3-small")

        # Sentence transformers (local, no API)
        embedding = generator.generate("Hello world", model="all-MiniLM-L6-v2")

        # Batch generation
        embeddings = generator.generate_batch(
            ["text1", "text2", "text3"],
            model="text-embedding-3-small"
        )
    """

    # Model configurations: provider routing, output dimensionality, input
    # token limit, and API cost in USD per million tokens (0.0 = local/free).
    MODELS = {
        # OpenAI models
        "text-embedding-3-small": {
            "provider": "openai",
            "dimensions": 1536,
            "max_tokens": 8191,
            "cost_per_million": 0.02,
        },
        "text-embedding-3-large": {
            "provider": "openai",
            "dimensions": 3072,
            "max_tokens": 8191,
            "cost_per_million": 0.13,
        },
        "text-embedding-ada-002": {
            "provider": "openai",
            "dimensions": 1536,
            "max_tokens": 8191,
            "cost_per_million": 0.10,
        },
        # Voyage AI models (recommended by Anthropic)
        "voyage-3": {
            "provider": "voyage",
            "dimensions": 1024,
            "max_tokens": 32000,
            "cost_per_million": 0.06,
        },
        "voyage-3-lite": {
            "provider": "voyage",
            "dimensions": 512,
            "max_tokens": 32000,
            "cost_per_million": 0.06,
        },
        "voyage-large-2": {
            "provider": "voyage",
            "dimensions": 1536,
            "max_tokens": 16000,
            "cost_per_million": 0.12,
        },
        "voyage-code-2": {
            "provider": "voyage",
            "dimensions": 1536,
            "max_tokens": 16000,
            "cost_per_million": 0.12,
        },
        "voyage-2": {
            "provider": "voyage",
            "dimensions": 1024,
            "max_tokens": 4000,
            "cost_per_million": 0.10,
        },
        # Sentence transformer models (local, free)
        "all-MiniLM-L6-v2": {
            "provider": "sentence-transformers",
            "dimensions": 384,
            "max_tokens": 256,
            "cost_per_million": 0.0,
        },
        "all-mpnet-base-v2": {
            "provider": "sentence-transformers",
            "dimensions": 768,
            "max_tokens": 384,
            "cost_per_million": 0.0,
        },
        "paraphrase-MiniLM-L6-v2": {
            "provider": "sentence-transformers",
            "dimensions": 384,
            "max_tokens": 128,
            "cost_per_million": 0.0,
        },
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        voyage_api_key: Optional[str] = None,
        cache_dir: Optional[str] = None
    ):
        """
        Initialize embedding generator.

        Clients are only constructed when both the SDK is importable and a
        key is available; otherwise they stay None and the corresponding
        _generate_* method raises on use.

        Args:
            api_key: API key for OpenAI (falls back to OPENAI_API_KEY env var)
            voyage_api_key: API key for Voyage AI (falls back to VOYAGE_API_KEY)
            cache_dir: Directory for caching models (sentence-transformers)
        """
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.voyage_api_key = voyage_api_key or os.getenv("VOYAGE_API_KEY")
        self.cache_dir = cache_dir

        # Initialize OpenAI client
        if OPENAI_AVAILABLE and self.api_key:
            self.openai_client = OpenAI(api_key=self.api_key)
        else:
            self.openai_client = None

        # Initialize Voyage AI client
        if VOYAGE_AVAILABLE and self.voyage_api_key:
            self.voyage_client = voyageai.Client(api_key=self.voyage_api_key)
        else:
            self.voyage_client = None

        # Memoized sentence-transformer model instances, keyed by model name.
        self._st_models = {}

    def get_model_info(self, model: str) -> dict:
        """Return the MODELS config entry for *model*.

        Raises:
            ValueError: If the model name is unknown.
        """
        if model not in self.MODELS:
            raise ValueError(
                f"Unknown model: {model}. "
                f"Available models: {', '.join(self.MODELS.keys())}"
            )
        return self.MODELS[model]

    def list_models(self) -> List[dict]:
        """List all available models with their configuration."""
        models = []
        for name, info in self.MODELS.items():
            models.append({
                "name": name,
                "provider": info["provider"],
                "dimensions": info["dimensions"],
                "max_tokens": info["max_tokens"],
                "cost_per_million": info.get("cost_per_million", 0.0),
            })
        return models

    def generate(
        self,
        text: str,
        model: str = "text-embedding-3-small",
        normalize: bool = True
    ) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            model: Model name
            normalize: Whether to normalize to unit length

        Returns:
            Embedding vector

        Raises:
            ValueError: If model is not supported or no API key is configured
            RuntimeError: If embedding generation fails
            ImportError: If the provider SDK is not installed
        """
        model_info = self.get_model_info(model)
        provider = model_info["provider"]

        if provider == "openai":
            return self._generate_openai(text, model, normalize)
        elif provider == "voyage":
            return self._generate_voyage(text, model, normalize)
        elif provider == "sentence-transformers":
            return self._generate_sentence_transformer(text, model, normalize)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def generate_batch(
        self,
        texts: List[str],
        model: str = "text-embedding-3-small",
        normalize: bool = True,
        batch_size: int = 32
    ) -> Tuple[List[List[float]], int]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: List of texts to embed
            model: Model name
            normalize: Whether to normalize to unit length
            batch_size: Batch size for processing

        Returns:
            Tuple of (embeddings list, dimensions); dimensions is 0 when
            texts is empty.

        Raises:
            ValueError: If model is not supported or no API key is configured
            RuntimeError: If embedding generation fails
            ImportError: If the provider SDK is not installed
        """
        model_info = self.get_model_info(model)
        provider = model_info["provider"]

        if provider == "openai":
            return self._generate_openai_batch(texts, model, normalize, batch_size)
        elif provider == "voyage":
            return self._generate_voyage_batch(texts, model, normalize, batch_size)
        elif provider == "sentence-transformers":
            return self._generate_sentence_transformer_batch(texts, model, normalize, batch_size)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def _require_openai(self) -> None:
        """Raise if the OpenAI SDK or client is unavailable."""
        if not OPENAI_AVAILABLE:
            raise ImportError(
                "OpenAI is required for OpenAI embeddings. "
                "Install with: pip install openai"
            )
        if not self.openai_client:
            raise ValueError("OpenAI API key not provided")

    def _require_voyage(self) -> None:
        """Raise if the Voyage AI SDK or client is unavailable."""
        if not VOYAGE_AVAILABLE:
            raise ImportError(
                "voyageai is required for Voyage AI embeddings. "
                "Install with: pip install voyageai"
            )
        if not self.voyage_client:
            raise ValueError("Voyage API key not provided")

    def _generate_openai(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate a single embedding using the OpenAI API."""
        self._require_openai()

        try:
            response = self.openai_client.embeddings.create(
                input=text,
                model=model
            )
            embedding = response.data[0].embedding

            if normalize:
                embedding = self._normalize(embedding)

            return embedding
        except Exception as e:
            # RuntimeError (a subclass of Exception) with an explicit cause:
            # keeps existing `except Exception` callers working while
            # preserving the original traceback.
            raise RuntimeError(f"OpenAI embedding generation failed: {e}") from e

    def _generate_openai_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings using the OpenAI API in batches."""
        self._require_openai()

        all_embeddings = []

        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            try:
                response = self.openai_client.embeddings.create(
                    input=batch,
                    model=model
                )

                batch_embeddings = [item.embedding for item in response.data]

                if normalize:
                    batch_embeddings = [self._normalize(emb) for emb in batch_embeddings]

                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                raise RuntimeError(f"OpenAI batch embedding generation failed: {e}") from e

        dimensions = len(all_embeddings[0]) if all_embeddings else 0
        return all_embeddings, dimensions

    def _generate_voyage(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate a single embedding using the Voyage AI API."""
        self._require_voyage()

        try:
            result = self.voyage_client.embed(
                texts=[text],
                model=model
            )
            embedding = result.embeddings[0]

            if normalize:
                embedding = self._normalize(embedding)

            return embedding
        except Exception as e:
            raise RuntimeError(f"Voyage AI embedding generation failed: {e}") from e

    def _generate_voyage_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings using the Voyage AI API in batches."""
        self._require_voyage()

        all_embeddings = []

        # Process in batches (Voyage AI supports up to 128 texts per request)
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            try:
                result = self.voyage_client.embed(
                    texts=batch,
                    model=model
                )

                batch_embeddings = result.embeddings

                if normalize:
                    batch_embeddings = [self._normalize(emb) for emb in batch_embeddings]

                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                raise RuntimeError(f"Voyage AI batch embedding generation failed: {e}") from e

        dimensions = len(all_embeddings[0]) if all_embeddings else 0
        return all_embeddings, dimensions

    def _get_st_model(self, model: str) -> "SentenceTransformer":
        """Load (and memoize) a sentence-transformers model.

        Raises:
            ImportError: If sentence-transformers is not installed.
        """
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            raise ImportError(
                "sentence-transformers is required for local embeddings. "
                "Install with: pip install sentence-transformers"
            )
        if model not in self._st_models:
            self._st_models[model] = SentenceTransformer(model, cache_folder=self.cache_dir)
        return self._st_models[model]

    def _generate_sentence_transformer(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate a single embedding using sentence-transformers (local)."""
        st_model = self._get_st_model(model)

        embedding = st_model.encode(text, normalize_embeddings=normalize)

        return embedding.tolist()

    def _generate_sentence_transformer_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings using sentence-transformers in batches."""
        st_model = self._get_st_model(model)

        embeddings = st_model.encode(
            texts,
            batch_size=batch_size,
            normalize_embeddings=normalize,
            show_progress_bar=False
        )

        dimensions = len(embeddings[0]) if len(embeddings) > 0 else 0
        return embeddings.tolist(), dimensions

    @staticmethod
    def _normalize(embedding: List[float]) -> List[float]:
        """Normalize embedding to unit length (zero vectors are returned as-is)."""
        vec = np.array(embedding)
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        return vec.tolist()

    @staticmethod
    def compute_hash(text: str, model: str) -> str:
        """Compute the SHA-256 cache key for a (model, text) pair."""
        content = f"{model}:{text}"
        return hashlib.sha256(content.encode()).hexdigest()
|
||||
157
src/skill_seekers/embedding/models.py
Normal file
157
src/skill_seekers/embedding/models.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Pydantic models for embedding API.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
    """Request model for single embedding generation (POST /embed)."""

    text: str = Field(..., description="Text to generate embedding for")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    normalize: bool = Field(
        default=True,
        description="Normalize embeddings to unit length"
    )

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "text": "This is a test document about Python programming.",
                "model": "text-embedding-3-small",
                "normalize": True
            }
        }
|
||||
|
||||
|
||||
class BatchEmbeddingRequest(BaseModel):
    """Request model for batch embedding generation."""

    texts: List[str] = Field(..., description="List of texts to embed")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    normalize: bool = Field(
        default=True,
        description="Normalize embeddings to unit length"
    )
    batch_size: Optional[int] = Field(
        default=32,
        description="Batch size for processing (default: 32)"
    )

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "texts": [
                    "First document about Python",
                    "Second document about JavaScript",
                    "Third document about Rust"
                ],
                "model": "text-embedding-3-small",
                "normalize": True,
                "batch_size": 32
            }
        }
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
    """Response model for single embedding generation."""

    embedding: List[float] = Field(..., description="Generated embedding vector")
    model: str = Field(..., description="Model used for generation")
    dimensions: int = Field(..., description="Embedding dimensions")
    cached: bool = Field(
        default=False,
        description="Whether embedding was retrieved from cache"
    )
|
||||
|
||||
|
||||
class BatchEmbeddingResponse(BaseModel):
    """Response model for batch embedding generation."""

    embeddings: List[List[float]] = Field(..., description="List of embedding vectors")
    model: str = Field(..., description="Model used for generation")
    dimensions: int = Field(..., description="Embedding dimensions")
    count: int = Field(..., description="Number of embeddings generated")
    cached_count: int = Field(
        default=0,
        description="Number of embeddings retrieved from cache"
    )
|
||||
|
||||
|
||||
class SkillEmbeddingRequest(BaseModel):
    """Request model for skill content embedding."""

    skill_path: str = Field(..., description="Path to skill directory")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    chunk_size: int = Field(
        default=512,
        description="Chunk size for splitting documents (tokens)"
    )
    overlap: int = Field(
        default=50,
        description="Overlap between chunks (tokens)"
    )

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "skill_path": "/path/to/skill/react",
                "model": "text-embedding-3-small",
                "chunk_size": 512,
                "overlap": 50
            }
        }
|
||||
|
||||
|
||||
class SkillEmbeddingResponse(BaseModel):
    """Response model for skill content embedding."""

    skill_name: str = Field(..., description="Name of the skill")
    total_chunks: int = Field(..., description="Total number of chunks embedded")
    model: str = Field(..., description="Model used for generation")
    dimensions: int = Field(..., description="Embedding dimensions")
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Skill metadata"
    )
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Health check response (GET /health)."""

    status: str = Field(..., description="Service status")
    version: str = Field(..., description="API version")
    models: List[str] = Field(..., description="Available embedding models")
    cache_enabled: bool = Field(..., description="Whether cache is enabled")
    # None when caching is disabled.
    cache_size: Optional[int] = Field(None, description="Number of cached embeddings")
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
    """Information about an embedding model."""

    name: str = Field(..., description="Model name")
    provider: str = Field(..., description="Model provider (openai, anthropic, sentence-transformers)")
    dimensions: int = Field(..., description="Embedding dimensions")
    max_tokens: int = Field(..., description="Maximum input tokens")
    cost_per_million: Optional[float] = Field(
        None,
        description="Cost per million tokens (if applicable)"
    )
|
||||
|
||||
|
||||
class ModelsResponse(BaseModel):
    """Response model for listing available models (GET /models)."""

    models: List[ModelInfo] = Field(..., description="List of available models")
    count: int = Field(..., description="Number of available models")
|
||||
362
src/skill_seekers/embedding/server.py
Normal file
362
src/skill_seekers/embedding/server.py
Normal file
@@ -0,0 +1,362 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FastAPI server for embedding generation.
|
||||
|
||||
Provides endpoints for:
|
||||
- Single and batch embedding generation
|
||||
- Skill content embedding
|
||||
- Model listing and information
|
||||
- Cache management
|
||||
- Health checks
|
||||
|
||||
Usage:
|
||||
# Start server
|
||||
python -m skill_seekers.embedding.server
|
||||
|
||||
# Or with uvicorn
|
||||
uvicorn skill_seekers.embedding.server:app --host 0.0.0.0 --port 8000
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
# FastAPI/uvicorn are optional server dependencies; the flag is checked
# before the app object is constructed below.
try:
    from fastapi import FastAPI, HTTPException, Query
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import JSONResponse
    import uvicorn
    FASTAPI_AVAILABLE = True
except ImportError:
    FASTAPI_AVAILABLE = False
|
||||
|
||||
from .models import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
BatchEmbeddingRequest,
|
||||
BatchEmbeddingResponse,
|
||||
SkillEmbeddingRequest,
|
||||
SkillEmbeddingResponse,
|
||||
HealthResponse,
|
||||
ModelInfo,
|
||||
ModelsResponse,
|
||||
)
|
||||
from .generator import EmbeddingGenerator
|
||||
from .cache import EmbeddingCache
|
||||
|
||||
|
||||
# Initialize FastAPI app
if FASTAPI_AVAILABLE:
    app = FastAPI(
        title="Skill Seekers Embedding API",
        description="Generate embeddings for text and skill content",
        version="1.0.0",
        docs_url="/docs",
        redoc_url="/redoc"
    )

    # Add CORS middleware
    # NOTE(review): allow_origins=["*"] together with allow_credentials=True
    # is wide open — confirm this is intended for local-only deployments.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# Initialize generator and cache (module-level singletons shared by all
# request handlers; configured entirely from environment variables).
cache_dir = os.getenv("EMBEDDING_CACHE_DIR", os.path.expanduser("~/.cache/skill-seekers/embeddings"))
cache_db = os.path.join(cache_dir, "embeddings.db")
cache_enabled = os.getenv("EMBEDDING_CACHE_ENABLED", "true").lower() == "true"

generator = EmbeddingGenerator(
    api_key=os.getenv("OPENAI_API_KEY"),
    voyage_api_key=os.getenv("VOYAGE_API_KEY")
)
# cache is None when caching is disabled via EMBEDDING_CACHE_ENABLED=false.
cache = EmbeddingCache(cache_db) if cache_enabled else None
|
||||
|
||||
@app.get("/", response_model=dict)
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"service": "Skill Seekers Embedding API",
|
||||
"version": "1.0.0",
|
||||
"docs": "/docs",
|
||||
"health": "/health"
|
||||
}
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
models = [m["name"] for m in generator.list_models()]
|
||||
cache_size = cache.size() if cache else None
|
||||
|
||||
return HealthResponse(
|
||||
status="ok",
|
||||
version="1.0.0",
|
||||
models=models,
|
||||
cache_enabled=cache_enabled,
|
||||
cache_size=cache_size
|
||||
)
|
||||
|
||||
@app.get("/models", response_model=ModelsResponse)
|
||||
async def list_models():
|
||||
"""List available embedding models."""
|
||||
models_list = generator.list_models()
|
||||
|
||||
model_infos = [
|
||||
ModelInfo(
|
||||
name=m["name"],
|
||||
provider=m["provider"],
|
||||
dimensions=m["dimensions"],
|
||||
max_tokens=m["max_tokens"],
|
||||
cost_per_million=m.get("cost_per_million")
|
||||
)
|
||||
for m in models_list
|
||||
]
|
||||
|
||||
return ModelsResponse(
|
||||
models=model_infos,
|
||||
count=len(model_infos)
|
||||
)
|
||||
|
||||
@app.post("/embed", response_model=EmbeddingResponse)
|
||||
async def embed_text(request: EmbeddingRequest):
|
||||
"""
|
||||
Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
request: Embedding request
|
||||
|
||||
Returns:
|
||||
Embedding response
|
||||
|
||||
Raises:
|
||||
HTTPException: If embedding generation fails
|
||||
"""
|
||||
try:
|
||||
# Check cache
|
||||
cached = False
|
||||
hash_key = generator.compute_hash(request.text, request.model)
|
||||
|
||||
if cache and cache.has(hash_key):
|
||||
embedding = cache.get(hash_key)
|
||||
cached = True
|
||||
else:
|
||||
# Generate embedding
|
||||
embedding = generator.generate(
|
||||
request.text,
|
||||
model=request.model,
|
||||
normalize=request.normalize
|
||||
)
|
||||
|
||||
# Store in cache
|
||||
if cache:
|
||||
cache.set(hash_key, embedding, request.model)
|
||||
|
||||
return EmbeddingResponse(
|
||||
embedding=embedding,
|
||||
model=request.model,
|
||||
dimensions=len(embedding),
|
||||
cached=cached
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/embed/batch", response_model=BatchEmbeddingResponse)
|
||||
async def embed_batch(request: BatchEmbeddingRequest):
|
||||
"""
|
||||
Generate embeddings for multiple texts.
|
||||
|
||||
Args:
|
||||
request: Batch embedding request
|
||||
|
||||
Returns:
|
||||
Batch embedding response
|
||||
|
||||
Raises:
|
||||
HTTPException: If embedding generation fails
|
||||
"""
|
||||
try:
|
||||
# Check cache for each text
|
||||
cached_count = 0
|
||||
embeddings = []
|
||||
texts_to_generate = []
|
||||
text_indices = []
|
||||
|
||||
for idx, text in enumerate(request.texts):
|
||||
hash_key = generator.compute_hash(text, request.model)
|
||||
|
||||
if cache and cache.has(hash_key):
|
||||
cached_embedding = cache.get(hash_key)
|
||||
embeddings.append(cached_embedding)
|
||||
cached_count += 1
|
||||
else:
|
||||
embeddings.append(None) # Placeholder
|
||||
texts_to_generate.append(text)
|
||||
text_indices.append(idx)
|
||||
|
||||
# Generate embeddings for uncached texts
|
||||
if texts_to_generate:
|
||||
generated_embeddings, dimensions = generator.generate_batch(
|
||||
texts_to_generate,
|
||||
model=request.model,
|
||||
normalize=request.normalize,
|
||||
batch_size=request.batch_size
|
||||
)
|
||||
|
||||
# Fill in placeholders and cache
|
||||
for idx, text, embedding in zip(text_indices, texts_to_generate, generated_embeddings):
|
||||
embeddings[idx] = embedding
|
||||
|
||||
if cache:
|
||||
hash_key = generator.compute_hash(text, request.model)
|
||||
cache.set(hash_key, embedding, request.model)
|
||||
|
||||
dimensions = len(embeddings[0]) if embeddings else 0
|
||||
|
||||
return BatchEmbeddingResponse(
|
||||
embeddings=embeddings,
|
||||
model=request.model,
|
||||
dimensions=dimensions,
|
||||
count=len(embeddings),
|
||||
cached_count=cached_count
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/embed/skill", response_model=SkillEmbeddingResponse)
|
||||
async def embed_skill(request: SkillEmbeddingRequest):
|
||||
"""
|
||||
Generate embeddings for skill content.
|
||||
|
||||
Args:
|
||||
request: Skill embedding request
|
||||
|
||||
Returns:
|
||||
Skill embedding response
|
||||
|
||||
Raises:
|
||||
HTTPException: If skill embedding fails
|
||||
"""
|
||||
try:
|
||||
skill_path = Path(request.skill_path)
|
||||
|
||||
if not skill_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Skill path not found: {request.skill_path}")
|
||||
|
||||
# Read SKILL.md
|
||||
skill_md = skill_path / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
raise HTTPException(status_code=404, detail=f"SKILL.md not found in {request.skill_path}")
|
||||
|
||||
skill_content = skill_md.read_text()
|
||||
|
||||
# Simple chunking (split by double newline)
|
||||
chunks = [
|
||||
chunk.strip()
|
||||
for chunk in skill_content.split("\n\n")
|
||||
if chunk.strip() and len(chunk.strip()) > 50
|
||||
]
|
||||
|
||||
# Generate embeddings for chunks
|
||||
embeddings, dimensions = generator.generate_batch(
|
||||
chunks,
|
||||
model=request.model,
|
||||
normalize=True,
|
||||
batch_size=32
|
||||
)
|
||||
|
||||
# TODO: Store embeddings in vector database
|
||||
# This would integrate with the vector database adaptors
|
||||
|
||||
return SkillEmbeddingResponse(
|
||||
skill_name=skill_path.name,
|
||||
total_chunks=len(chunks),
|
||||
model=request.model,
|
||||
dimensions=dimensions,
|
||||
metadata={
|
||||
"skill_path": str(skill_path),
|
||||
"chunks": len(chunks),
|
||||
"content_length": len(skill_content)
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/cache/stats", response_model=dict)
|
||||
async def cache_stats():
|
||||
"""Get cache statistics."""
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache is disabled")
|
||||
|
||||
return cache.stats()
|
||||
|
||||
@app.post("/cache/clear", response_model=dict)
|
||||
async def clear_cache(
|
||||
model: Optional[str] = Query(None, description="Model to clear (all if not specified)")
|
||||
):
|
||||
"""Clear cache entries."""
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache is disabled")
|
||||
|
||||
deleted = cache.clear(model=model)
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"deleted": deleted,
|
||||
"model": model or "all"
|
||||
}
|
||||
|
||||
@app.post("/cache/clear-expired", response_model=dict)
|
||||
async def clear_expired():
|
||||
"""Clear expired cache entries."""
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache is disabled")
|
||||
|
||||
deleted = cache.clear_expired()
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"deleted": deleted
|
||||
}
|
||||
|
||||
else:
    # Hard-fail at import time when FastAPI is missing: this module exists
    # solely to serve the API, so there is nothing useful to export without it.
    print("Error: FastAPI not available. Install with: pip install fastapi uvicorn")
    sys.exit(1)
|
||||
|
||||
|
||||
def main():
    """Entry point: read env configuration and launch the uvicorn server."""
    if not FASTAPI_AVAILABLE:
        print("Error: FastAPI not available. Install with: pip install fastapi uvicorn")
        sys.exit(1)

    # Get configuration from environment
    host = os.getenv("EMBEDDING_HOST", "0.0.0.0")
    port = int(os.getenv("EMBEDDING_PORT", "8000"))
    auto_reload = os.getenv("EMBEDDING_RELOAD", "false").lower() == "true"

    print(f"🚀 Starting Embedding API server on {host}:{port}")
    print(f"📚 API documentation: http://{host}:{port}/docs")
    print(f"🔍 Cache enabled: {cache_enabled}")
    if cache_enabled:
        print(f"💾 Cache database: {cache_db}")

    # Import string (not the app object) so --reload works correctly.
    uvicorn.run(
        "skill_seekers.embedding.server:app",
        host=host,
        port=port,
        reload=auto_reload,
    )
|
||||
|
||||
|
||||
# Allow running the server directly: `python -m skill_seekers.embedding.server`.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user