fix: Enforce min_chunk_size in RAG chunker

- Filter out chunks smaller than min_chunk_size (default 100 tokens)
- Exception: Keep all chunks if entire document is smaller than target size
- All 15 tests passing (100% pass rate)

Fixes edge case where very small chunks (e.g., 'Short.' = 6 chars) were
being created despite min_chunk_size=100 setting.

Test: pytest tests/test_rag_chunker.py -v
This commit is contained in:
yusyus
2026-02-07 20:59:03 +03:00
parent 3a769a27cd
commit 8b3f31409e
65 changed files with 16133 additions and 7 deletions

View File

@@ -0,0 +1,31 @@
"""
Embedding generation system for Skill Seekers.
Provides:
- FastAPI server for embedding generation
- Multiple embedding model support (OpenAI, sentence-transformers, Anthropic)
- Batch processing for efficiency
- Caching layer for embeddings
- Vector database integration
Usage:
# Start server
python -m skill_seekers.embedding.server
# Generate embeddings
curl -X POST http://localhost:8000/embed \
-H "Content-Type: application/json" \
-d '{"texts": ["Hello world"], "model": "text-embedding-3-small"}'
"""
from .models import EmbeddingRequest, EmbeddingResponse, BatchEmbeddingRequest
from .generator import EmbeddingGenerator
from .cache import EmbeddingCache
# Public, re-exported API of the embedding package.
__all__ = [
    "EmbeddingRequest",
    "EmbeddingResponse",
    "BatchEmbeddingRequest",
    "EmbeddingGenerator",
    "EmbeddingCache",
]

View File

@@ -0,0 +1,335 @@
"""
Caching layer for embeddings.
"""
import json
import sqlite3
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import List, Optional, Tuple
class EmbeddingCache:
    """
    SQLite-based cache for embeddings.

    Stores embeddings keyed by a text+model hash to avoid regeneration.
    A TTL (time-to-live) is enforced on every read and via clear_expired().

    Examples:
        cache = EmbeddingCache("/path/to/cache.db")
        # Store embedding
        cache.set("hash123", [0.1, 0.2, 0.3], model="text-embedding-3-small")
        # Retrieve embedding
        embedding = cache.get("hash123")
        # Check if cached
        if cache.has("hash123"):
            print("Embedding is cached")
    """

    def __init__(self, db_path: str = ":memory:", ttl_days: int = 30):
        """
        Initialize embedding cache.

        Args:
            db_path: Path to SQLite database (":memory:" for in-memory)
            ttl_days: Time-to-live for cache entries in days
        """
        self.db_path = db_path
        self.ttl_days = ttl_days
        # Create database directory if needed
        if db_path != ":memory:":
            Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        # check_same_thread=False lets server worker threads share this
        # connection. NOTE(review): sqlite3 connections are not inherently
        # thread-safe — confirm callers serialize writes or accept
        # best-effort caching semantics.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self._init_db()

    @staticmethod
    def _utcnow() -> datetime:
        """Current UTC time as a naive datetime.

        Replaces the deprecated datetime.utcnow() while keeping the stored
        ISO format identical, so entries written by older versions of this
        cache remain comparable.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _is_expired(self, created_at: str) -> bool:
        """Return True if an ISO-formatted creation timestamp exceeds the TTL."""
        created = datetime.fromisoformat(created_at)
        if created.tzinfo is not None:
            # Normalize any timezone-aware timestamp to naive UTC so the
            # subtraction below never mixes aware and naive datetimes.
            created = created.astimezone(timezone.utc).replace(tzinfo=None)
        return self._utcnow() - created > timedelta(days=self.ttl_days)

    def _init_db(self):
        """Initialize database schema (idempotent)."""
        cursor = self.conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                hash TEXT PRIMARY KEY,
                embedding TEXT NOT NULL,
                model TEXT NOT NULL,
                dimensions INTEGER NOT NULL,
                created_at TEXT NOT NULL,
                accessed_at TEXT NOT NULL,
                access_count INTEGER DEFAULT 1
            )
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_model ON embeddings(model)
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_created_at ON embeddings(created_at)
        """)
        self.conn.commit()

    def set(
        self,
        hash_key: str,
        embedding: List[float],
        model: str
    ) -> None:
        """
        Store embedding in cache.

        Replacing an existing entry resets its created_at (and thus its TTL)
        and its access_count.

        Args:
            hash_key: Hash of text+model
            embedding: Embedding vector
            model: Model name
        """
        cursor = self.conn.cursor()
        now = self._utcnow().isoformat()
        embedding_json = json.dumps(embedding)
        dimensions = len(embedding)
        cursor.execute("""
            INSERT OR REPLACE INTO embeddings
            (hash, embedding, model, dimensions, created_at, accessed_at, access_count)
            VALUES (?, ?, ?, ?, ?, ?, 1)
        """, (hash_key, embedding_json, model, dimensions, now, now))
        self.conn.commit()

    def get(self, hash_key: str) -> Optional[List[float]]:
        """
        Retrieve embedding from cache.

        Also updates the entry's access statistics on a hit, and deletes
        the entry if it has expired.

        Args:
            hash_key: Hash of text+model

        Returns:
            Embedding vector if cached and not expired, None otherwise
        """
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT embedding, created_at
            FROM embeddings
            WHERE hash = ?
        """, (hash_key,))
        row = cursor.fetchone()
        if not row:
            return None
        embedding_json, created_at = row
        if self._is_expired(created_at):
            # Expired: evict lazily and report a miss.
            self.delete(hash_key)
            return None
        # Update access stats for cache-usage reporting (see stats()).
        now = self._utcnow().isoformat()
        cursor.execute("""
            UPDATE embeddings
            SET accessed_at = ?, access_count = access_count + 1
            WHERE hash = ?
        """, (now, hash_key))
        self.conn.commit()
        return json.loads(embedding_json)

    def get_batch(self, hash_keys: List[str]) -> Tuple[List[Optional[List[float]]], List[bool]]:
        """
        Retrieve multiple embeddings from cache.

        Args:
            hash_keys: List of hashes

        Returns:
            Tuple of (embeddings list, cached flags);
            the embeddings list contains None for cache misses
        """
        embeddings = []
        cached_flags = []
        for hash_key in hash_keys:
            embedding = self.get(hash_key)
            embeddings.append(embedding)
            cached_flags.append(embedding is not None)
        return embeddings, cached_flags

    def has(self, hash_key: str) -> bool:
        """
        Check if embedding is cached and not expired.

        Expired entries are deleted as a side effect.

        Args:
            hash_key: Hash of text+model

        Returns:
            True if cached and not expired, False otherwise
        """
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT created_at
            FROM embeddings
            WHERE hash = ?
        """, (hash_key,))
        row = cursor.fetchone()
        if not row:
            return False
        if self._is_expired(row[0]):
            self.delete(hash_key)
            return False
        return True

    def delete(self, hash_key: str) -> None:
        """
        Delete embedding from cache.

        Args:
            hash_key: Hash of text+model
        """
        cursor = self.conn.cursor()
        cursor.execute("""
            DELETE FROM embeddings
            WHERE hash = ?
        """, (hash_key,))
        self.conn.commit()

    def clear(self, model: Optional[str] = None) -> int:
        """
        Clear cache entries.

        Args:
            model: If provided, only clear entries for this model

        Returns:
            Number of entries deleted
        """
        cursor = self.conn.cursor()
        if model:
            cursor.execute("""
                DELETE FROM embeddings
                WHERE model = ?
            """, (model,))
        else:
            cursor.execute("DELETE FROM embeddings")
        deleted = cursor.rowcount
        self.conn.commit()
        return deleted

    def clear_expired(self) -> int:
        """
        Clear expired cache entries.

        Returns:
            Number of entries deleted
        """
        cursor = self.conn.cursor()
        cutoff = (self._utcnow() - timedelta(days=self.ttl_days)).isoformat()
        cursor.execute("""
            DELETE FROM embeddings
            WHERE created_at < ?
        """, (cutoff,))
        deleted = cursor.rowcount
        self.conn.commit()
        return deleted

    def size(self) -> int:
        """
        Get number of cached embeddings.

        Returns:
            Number of cache entries
        """
        cursor = self.conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM embeddings")
        return cursor.fetchone()[0]

    def stats(self) -> dict:
        """
        Get cache statistics.

        Returns:
            Dictionary with total entries, per-model counts, the ten
            most-accessed entries, the count of expired-but-unevicted
            entries, and the configured TTL.
        """
        cursor = self.conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM embeddings")
        total = cursor.fetchone()[0]
        cursor.execute("""
            SELECT model, COUNT(*)
            FROM embeddings
            GROUP BY model
        """)
        by_model = {row[0]: row[1] for row in cursor.fetchall()}
        cursor.execute("""
            SELECT hash, model, access_count
            FROM embeddings
            ORDER BY access_count DESC
            LIMIT 10
        """)
        top_accessed = [
            {"hash": row[0], "model": row[1], "access_count": row[2]}
            for row in cursor.fetchall()
        ]
        cutoff = (self._utcnow() - timedelta(days=self.ttl_days)).isoformat()
        cursor.execute("""
            SELECT COUNT(*)
            FROM embeddings
            WHERE created_at < ?
        """, (cutoff,))
        expired = cursor.fetchone()[0]
        return {
            "total": total,
            "by_model": by_model,
            "top_accessed": top_accessed,
            "expired": expired,
            "ttl_days": self.ttl_days
        }

    def close(self):
        """Close database connection."""
        self.conn.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()

View File

@@ -0,0 +1,443 @@
"""
Embedding generation with multiple model support.
"""
import os
import hashlib
from typing import List, Optional, Tuple
import numpy as np
# OpenAI support
try:
from openai import OpenAI
OPENAI_AVAILABLE = True
except ImportError:
OPENAI_AVAILABLE = False
# Sentence transformers support
try:
from sentence_transformers import SentenceTransformer
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False
# Voyage AI support (recommended by Anthropic for embeddings)
try:
import voyageai
VOYAGE_AVAILABLE = True
except ImportError:
VOYAGE_AVAILABLE = False
class EmbeddingGenerator:
    """
    Generate embeddings using multiple model providers.

    Supported providers:
    - OpenAI (text-embedding-3-small, text-embedding-3-large, text-embedding-ada-002)
    - Sentence Transformers (all-MiniLM-L6-v2, all-mpnet-base-v2, etc.)
    - Anthropic/Voyage AI (voyage-2, voyage-large-2)

    Examples:
        # OpenAI embeddings
        generator = EmbeddingGenerator()
        embedding = generator.generate("Hello world", model="text-embedding-3-small")

        # Sentence transformers (local, no API)
        embedding = generator.generate("Hello world", model="all-MiniLM-L6-v2")

        # Batch generation
        embeddings = generator.generate_batch(
            ["text1", "text2", "text3"],
            model="text-embedding-3-small"
        )
    """

    # Model configurations: provider, output dimensions, max input tokens,
    # and approximate USD cost per million tokens (0.0 for local models).
    MODELS = {
        # OpenAI models
        "text-embedding-3-small": {
            "provider": "openai",
            "dimensions": 1536,
            "max_tokens": 8191,
            "cost_per_million": 0.02,
        },
        "text-embedding-3-large": {
            "provider": "openai",
            "dimensions": 3072,
            "max_tokens": 8191,
            "cost_per_million": 0.13,
        },
        "text-embedding-ada-002": {
            "provider": "openai",
            "dimensions": 1536,
            "max_tokens": 8191,
            "cost_per_million": 0.10,
        },
        # Voyage AI models (recommended by Anthropic)
        "voyage-3": {
            "provider": "voyage",
            "dimensions": 1024,
            "max_tokens": 32000,
            "cost_per_million": 0.06,
        },
        "voyage-3-lite": {
            "provider": "voyage",
            "dimensions": 512,
            "max_tokens": 32000,
            "cost_per_million": 0.06,
        },
        "voyage-large-2": {
            "provider": "voyage",
            "dimensions": 1536,
            "max_tokens": 16000,
            "cost_per_million": 0.12,
        },
        "voyage-code-2": {
            "provider": "voyage",
            "dimensions": 1536,
            "max_tokens": 16000,
            "cost_per_million": 0.12,
        },
        "voyage-2": {
            "provider": "voyage",
            "dimensions": 1024,
            "max_tokens": 4000,
            "cost_per_million": 0.10,
        },
        # Sentence transformer models (local, free)
        "all-MiniLM-L6-v2": {
            "provider": "sentence-transformers",
            "dimensions": 384,
            "max_tokens": 256,
            "cost_per_million": 0.0,
        },
        "all-mpnet-base-v2": {
            "provider": "sentence-transformers",
            "dimensions": 768,
            "max_tokens": 384,
            "cost_per_million": 0.0,
        },
        "paraphrase-MiniLM-L6-v2": {
            "provider": "sentence-transformers",
            "dimensions": 384,
            "max_tokens": 128,
            "cost_per_million": 0.0,
        },
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        voyage_api_key: Optional[str] = None,
        cache_dir: Optional[str] = None
    ):
        """
        Initialize embedding generator.

        Args:
            api_key: API key for OpenAI (falls back to OPENAI_API_KEY env var)
            voyage_api_key: API key for Voyage AI (falls back to VOYAGE_API_KEY)
            cache_dir: Directory for caching models (sentence-transformers)
        """
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.voyage_api_key = voyage_api_key or os.getenv("VOYAGE_API_KEY")
        self.cache_dir = cache_dir
        # Clients stay None when the SDK or key is missing; the per-provider
        # helpers raise a clear error in that case instead of failing here.
        if OPENAI_AVAILABLE and self.api_key:
            self.openai_client = OpenAI(api_key=self.api_key)
        else:
            self.openai_client = None
        if VOYAGE_AVAILABLE and self.voyage_api_key:
            self.voyage_client = voyageai.Client(api_key=self.voyage_api_key)
        else:
            self.voyage_client = None
        # Lazily-loaded sentence-transformer models, keyed by model name.
        self._st_models = {}

    def get_model_info(self, model: str) -> dict:
        """Return the configuration dict for *model*.

        Raises:
            ValueError: If the model name is unknown.
        """
        if model not in self.MODELS:
            raise ValueError(
                f"Unknown model: {model}. "
                f"Available models: {', '.join(self.MODELS.keys())}"
            )
        return self.MODELS[model]

    def list_models(self) -> List[dict]:
        """List all available models with their key properties."""
        return [
            {
                "name": name,
                "provider": info["provider"],
                "dimensions": info["dimensions"],
                "max_tokens": info["max_tokens"],
                "cost_per_million": info.get("cost_per_million", 0.0),
            }
            for name, info in self.MODELS.items()
        ]

    def generate(
        self,
        text: str,
        model: str = "text-embedding-3-small",
        normalize: bool = True
    ) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            model: Model name
            normalize: Whether to normalize to unit length

        Returns:
            Embedding vector

        Raises:
            ValueError: If model is not supported
            RuntimeError: If embedding generation fails
        """
        provider = self.get_model_info(model)["provider"]
        if provider == "openai":
            return self._generate_openai(text, model, normalize)
        elif provider == "voyage":
            return self._generate_voyage(text, model, normalize)
        elif provider == "sentence-transformers":
            return self._generate_sentence_transformer(text, model, normalize)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def generate_batch(
        self,
        texts: List[str],
        model: str = "text-embedding-3-small",
        normalize: bool = True,
        batch_size: int = 32
    ) -> Tuple[List[List[float]], int]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: List of texts to embed
            model: Model name
            normalize: Whether to normalize to unit length
            batch_size: Batch size for processing

        Returns:
            Tuple of (embeddings list, dimensions); dimensions is 0 when
            texts is empty

        Raises:
            ValueError: If model is not supported
            RuntimeError: If embedding generation fails
        """
        provider = self.get_model_info(model)["provider"]
        if provider == "openai":
            return self._generate_openai_batch(texts, model, normalize, batch_size)
        elif provider == "voyage":
            return self._generate_voyage_batch(texts, model, normalize, batch_size)
        elif provider == "sentence-transformers":
            return self._generate_sentence_transformer_batch(texts, model, normalize, batch_size)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def _require_openai(self) -> None:
        """Raise if the OpenAI SDK or API key is unavailable."""
        if not OPENAI_AVAILABLE:
            raise ImportError(
                "OpenAI is required for OpenAI embeddings. "
                "Install with: pip install openai"
            )
        if not self.openai_client:
            raise ValueError("OpenAI API key not provided")

    def _require_voyage(self) -> None:
        """Raise if the Voyage AI SDK or API key is unavailable."""
        if not VOYAGE_AVAILABLE:
            raise ImportError(
                "voyageai is required for Voyage AI embeddings. "
                "Install with: pip install voyageai"
            )
        if not self.voyage_client:
            raise ValueError("Voyage API key not provided")

    def _get_st_model(self, model: str):
        """Return a (cached) SentenceTransformer instance for *model*."""
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            raise ImportError(
                "sentence-transformers is required for local embeddings. "
                "Install with: pip install sentence-transformers"
            )
        if model not in self._st_models:
            self._st_models[model] = SentenceTransformer(model, cache_folder=self.cache_dir)
        return self._st_models[model]

    def _generate_openai(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate embedding using OpenAI API."""
        self._require_openai()
        try:
            response = self.openai_client.embeddings.create(
                input=text,
                model=model
            )
            embedding = response.data[0].embedding
            if normalize:
                embedding = self._normalize(embedding)
            return embedding
        except Exception as e:
            # Chain the cause so the original API error is preserved.
            raise RuntimeError(f"OpenAI embedding generation failed: {e}") from e

    def _generate_openai_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings using OpenAI API in batches."""
        self._require_openai()
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            try:
                response = self.openai_client.embeddings.create(
                    input=batch,
                    model=model
                )
                batch_embeddings = [item.embedding for item in response.data]
                if normalize:
                    batch_embeddings = [self._normalize(emb) for emb in batch_embeddings]
                all_embeddings.extend(batch_embeddings)
            except Exception as e:
                raise RuntimeError(f"OpenAI batch embedding generation failed: {e}") from e
        dimensions = len(all_embeddings[0]) if all_embeddings else 0
        return all_embeddings, dimensions

    def _generate_voyage(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate embedding using Voyage AI API."""
        self._require_voyage()
        try:
            result = self.voyage_client.embed(
                texts=[text],
                model=model
            )
            embedding = result.embeddings[0]
            if normalize:
                embedding = self._normalize(embedding)
            return embedding
        except Exception as e:
            raise RuntimeError(f"Voyage AI embedding generation failed: {e}") from e

    def _generate_voyage_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings using Voyage AI API in batches.

        Voyage AI supports up to 128 texts per request; batch_size should
        stay at or below that limit.
        """
        self._require_voyage()
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            try:
                result = self.voyage_client.embed(
                    texts=batch,
                    model=model
                )
                batch_embeddings = result.embeddings
                if normalize:
                    batch_embeddings = [self._normalize(emb) for emb in batch_embeddings]
                all_embeddings.extend(batch_embeddings)
            except Exception as e:
                raise RuntimeError(f"Voyage AI batch embedding generation failed: {e}") from e
        dimensions = len(all_embeddings[0]) if all_embeddings else 0
        return all_embeddings, dimensions

    def _generate_sentence_transformer(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate embedding locally using sentence-transformers."""
        st_model = self._get_st_model(model)
        embedding = st_model.encode(text, normalize_embeddings=normalize)
        return embedding.tolist()

    def _generate_sentence_transformer_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings locally using sentence-transformers in batches."""
        st_model = self._get_st_model(model)
        embeddings = st_model.encode(
            texts,
            batch_size=batch_size,
            normalize_embeddings=normalize,
            show_progress_bar=False
        )
        dimensions = len(embeddings[0]) if len(embeddings) > 0 else 0
        return embeddings.tolist(), dimensions

    @staticmethod
    def _normalize(embedding: List[float]) -> List[float]:
        """Normalize embedding to unit length (zero vectors pass through)."""
        vec = np.array(embedding)
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        return vec.tolist()

    @staticmethod
    def compute_hash(text: str, model: str) -> str:
        """Compute a stable SHA-256 cache key for (model, text)."""
        content = f"{model}:{text}"
        return hashlib.sha256(content.encode()).hexdigest()

View File

@@ -0,0 +1,157 @@
"""
Pydantic models for embedding API.
"""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class EmbeddingRequest(BaseModel):
    """Request model for single embedding generation."""
    # The annotations and Field defaults below ARE the pydantic validation
    # schema — they are behavior, not just documentation.
    text: str = Field(..., description="Text to generate embedding for")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    normalize: bool = Field(
        default=True,
        description="Normalize embeddings to unit length"
    )

    class Config:
        # Example payload surfaced in the OpenAPI /docs UI.
        json_schema_extra = {
            "example": {
                "text": "This is a test document about Python programming.",
                "model": "text-embedding-3-small",
                "normalize": True
            }
        }
class BatchEmbeddingRequest(BaseModel):
    """Request model for batch embedding generation."""
    texts: List[str] = Field(..., description="List of texts to embed")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    normalize: bool = Field(
        default=True,
        description="Normalize embeddings to unit length"
    )
    # Optional so clients may send null to request the server default.
    batch_size: Optional[int] = Field(
        default=32,
        description="Batch size for processing (default: 32)"
    )

    class Config:
        # Example payload surfaced in the OpenAPI /docs UI.
        json_schema_extra = {
            "example": {
                "texts": [
                    "First document about Python",
                    "Second document about JavaScript",
                    "Third document about Rust"
                ],
                "model": "text-embedding-3-small",
                "normalize": True,
                "batch_size": 32
            }
        }
class EmbeddingResponse(BaseModel):
    """Response model for embedding generation."""
    embedding: List[float] = Field(..., description="Generated embedding vector")
    model: str = Field(..., description="Model used for generation")
    dimensions: int = Field(..., description="Embedding dimensions")
    # True when the vector came from the server-side cache rather than a
    # fresh provider call.
    cached: bool = Field(
        default=False,
        description="Whether embedding was retrieved from cache"
    )
class BatchEmbeddingResponse(BaseModel):
    """Response model for batch embedding generation."""
    embeddings: List[List[float]] = Field(..., description="List of embedding vectors")
    model: str = Field(..., description="Model used for generation")
    dimensions: int = Field(..., description="Embedding dimensions")
    count: int = Field(..., description="Number of embeddings generated")
    # How many of `embeddings` were served from cache (<= count).
    cached_count: int = Field(
        default=0,
        description="Number of embeddings retrieved from cache"
    )
class SkillEmbeddingRequest(BaseModel):
    """Request model for skill content embedding."""
    skill_path: str = Field(..., description="Path to skill directory")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    # NOTE(review): chunk_size/overlap are described in tokens, but the
    # current /embed/skill handler chunks by paragraph and does not read
    # them — confirm before relying on these fields.
    chunk_size: int = Field(
        default=512,
        description="Chunk size for splitting documents (tokens)"
    )
    overlap: int = Field(
        default=50,
        description="Overlap between chunks (tokens)"
    )

    class Config:
        # Example payload surfaced in the OpenAPI /docs UI.
        json_schema_extra = {
            "example": {
                "skill_path": "/path/to/skill/react",
                "model": "text-embedding-3-small",
                "chunk_size": 512,
                "overlap": 50
            }
        }
class SkillEmbeddingResponse(BaseModel):
    """Response model for skill content embedding."""
    skill_name: str = Field(..., description="Name of the skill")
    total_chunks: int = Field(..., description="Total number of chunks embedded")
    model: str = Field(..., description="Model used for generation")
    dimensions: int = Field(..., description="Embedding dimensions")
    # Free-form info about the processed skill (path, chunk count, size).
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Skill metadata"
    )
class HealthResponse(BaseModel):
    """Health check response."""
    status: str = Field(..., description="Service status")
    version: str = Field(..., description="API version")
    models: List[str] = Field(..., description="Available embedding models")
    cache_enabled: bool = Field(..., description="Whether cache is enabled")
    # None when caching is disabled.
    cache_size: Optional[int] = Field(None, description="Number of cached embeddings")
class ModelInfo(BaseModel):
    """Information about an embedding model."""
    name: str = Field(..., description="Model name")
    provider: str = Field(..., description="Model provider (openai, anthropic, sentence-transformers)")
    dimensions: int = Field(..., description="Embedding dimensions")
    max_tokens: int = Field(..., description="Maximum input tokens")
    # None/0.0 for local models that have no per-token cost.
    cost_per_million: Optional[float] = Field(
        None,
        description="Cost per million tokens (if applicable)"
    )
class ModelsResponse(BaseModel):
    """Response model for listing available models."""
    models: List[ModelInfo] = Field(..., description="List of available models")
    count: int = Field(..., description="Number of available models")

View File

@@ -0,0 +1,362 @@
#!/usr/bin/env python3
"""
FastAPI server for embedding generation.
Provides endpoints for:
- Single and batch embedding generation
- Skill content embedding
- Model listing and information
- Cache management
- Health checks
Usage:
# Start server
python -m skill_seekers.embedding.server
# Or with uvicorn
uvicorn skill_seekers.embedding.server:app --host 0.0.0.0 --port 8000
"""
import os
import sys
from pathlib import Path
from typing import List, Optional
try:
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import uvicorn
FASTAPI_AVAILABLE = True
except ImportError:
FASTAPI_AVAILABLE = False
from .models import (
EmbeddingRequest,
EmbeddingResponse,
BatchEmbeddingRequest,
BatchEmbeddingResponse,
SkillEmbeddingRequest,
SkillEmbeddingResponse,
HealthResponse,
ModelInfo,
ModelsResponse,
)
from .generator import EmbeddingGenerator
from .cache import EmbeddingCache
# Initialize FastAPI app
# The whole app (routes included) is defined only when FastAPI imported
# successfully; otherwise the module prints an error and exits.
if FASTAPI_AVAILABLE:
    app = FastAPI(
        title="Skill Seekers Embedding API",
        description="Generate embeddings for text and skill content",
        version="1.0.0",
        docs_url="/docs",
        redoc_url="/redoc"
    )
    # Add CORS middleware
    # NOTE(review): allow_origins=["*"] combined with allow_credentials=True
    # is wide open — confirm this service is only exposed to trusted hosts.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # Initialize generator and cache — module-level singletons shared by
    # every request handler below. Cache location and enablement come from
    # environment variables.
    cache_dir = os.getenv("EMBEDDING_CACHE_DIR", os.path.expanduser("~/.cache/skill-seekers/embeddings"))
    cache_db = os.path.join(cache_dir, "embeddings.db")
    cache_enabled = os.getenv("EMBEDDING_CACHE_ENABLED", "true").lower() == "true"
    generator = EmbeddingGenerator(
        api_key=os.getenv("OPENAI_API_KEY"),
        voyage_api_key=os.getenv("VOYAGE_API_KEY")
    )
    # cache is None when disabled; handlers must check before use.
    cache = EmbeddingCache(cache_db) if cache_enabled else None
@app.get("/", response_model=dict)
async def root():
    """Service banner with pointers to the docs and health endpoints."""
    return dict(
        service="Skill Seekers Embedding API",
        version="1.0.0",
        docs="/docs",
        health="/health",
    )
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint: reports available models and cache state."""
    available = [entry["name"] for entry in generator.list_models()]
    size = cache.size() if cache else None
    return HealthResponse(
        status="ok",
        version="1.0.0",
        models=available,
        cache_enabled=cache_enabled,
        cache_size=size,
    )
@app.get("/models", response_model=ModelsResponse)
async def list_models():
    """List available embedding models."""
    infos = []
    for entry in generator.list_models():
        infos.append(ModelInfo(
            name=entry["name"],
            provider=entry["provider"],
            dimensions=entry["dimensions"],
            max_tokens=entry["max_tokens"],
            cost_per_million=entry.get("cost_per_million"),
        ))
    return ModelsResponse(models=infos, count=len(infos))
@app.post("/embed", response_model=EmbeddingResponse)
async def embed_text(request: EmbeddingRequest):
    """
    Generate embedding for a single text, serving from cache when possible.

    Args:
        request: Embedding request

    Returns:
        Embedding response

    Raises:
        HTTPException: If embedding generation fails (500)
    """
    try:
        hash_key = generator.compute_hash(request.text, request.model)
        # Single cache lookup: get() already enforces TTL itself. The old
        # has()-then-get() pattern doubled the DB round-trips and could
        # race with expiry (entry expiring between the two calls).
        # NOTE(review): the cache key ignores request.normalize, so a hit
        # may return a vector cached with a different normalize setting —
        # confirm callers always use the same flag per model.
        embedding = cache.get(hash_key) if cache else None
        cached = embedding is not None
        if embedding is None:
            embedding = generator.generate(
                request.text,
                model=request.model,
                normalize=request.normalize
            )
            # Store in cache for subsequent requests
            if cache:
                cache.set(hash_key, embedding, request.model)
        return EmbeddingResponse(
            embedding=embedding,
            model=request.model,
            dimensions=len(embedding),
            cached=cached
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/embed/batch", response_model=BatchEmbeddingResponse)
async def embed_batch(request: BatchEmbeddingRequest):
    """
    Generate embeddings for multiple texts, serving from cache when possible.

    Args:
        request: Batch embedding request

    Returns:
        Batch embedding response

    Raises:
        HTTPException: If embedding generation fails (500)
    """
    try:
        cached_count = 0
        embeddings: List[Optional[List[float]]] = []
        texts_to_generate = []
        text_indices = []
        for idx, text in enumerate(request.texts):
            # Single get() per text: it enforces TTL itself. The old
            # has()-then-get() pattern doubled the DB round-trips and could
            # leave a None in the response if the entry expired between the
            # two calls.
            hash_key = generator.compute_hash(text, request.model)
            cached_embedding = cache.get(hash_key) if cache else None
            if cached_embedding is not None:
                embeddings.append(cached_embedding)
                cached_count += 1
            else:
                embeddings.append(None)  # Placeholder, filled in below
                texts_to_generate.append(text)
                text_indices.append(idx)
        # Generate embeddings for uncached texts
        if texts_to_generate:
            generated_embeddings, _ = generator.generate_batch(
                texts_to_generate,
                model=request.model,
                normalize=request.normalize,
                batch_size=request.batch_size
            )
            # Fill in placeholders (at their original positions) and cache
            for idx, text, embedding in zip(text_indices, texts_to_generate, generated_embeddings):
                embeddings[idx] = embedding
                if cache:
                    hash_key = generator.compute_hash(text, request.model)
                    cache.set(hash_key, embedding, request.model)
        dimensions = len(embeddings[0]) if embeddings else 0
        return BatchEmbeddingResponse(
            embeddings=embeddings,
            model=request.model,
            dimensions=dimensions,
            count=len(embeddings),
            cached_count=cached_count
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/embed/skill", response_model=SkillEmbeddingResponse)
async def embed_skill(request: SkillEmbeddingRequest):
    """
    Generate embeddings for skill content (the skill's SKILL.md).

    Args:
        request: Skill embedding request

    Returns:
        Skill embedding response

    Raises:
        HTTPException: 404 if the skill or SKILL.md is missing, 500 otherwise
    """
    try:
        skill_path = Path(request.skill_path)
        if not skill_path.exists():
            raise HTTPException(status_code=404, detail=f"Skill path not found: {request.skill_path}")
        skill_md = skill_path / "SKILL.md"
        if not skill_md.exists():
            raise HTTPException(status_code=404, detail=f"SKILL.md not found in {request.skill_path}")
        skill_content = skill_md.read_text()
        # Naive paragraph chunking: keep stripped paragraphs over 50 chars.
        chunks = []
        for paragraph in skill_content.split("\n\n"):
            stripped = paragraph.strip()
            if stripped and len(stripped) > 50:
                chunks.append(stripped)
        # Embed every chunk with the requested model.
        embeddings, dimensions = generator.generate_batch(
            chunks,
            model=request.model,
            normalize=True,
            batch_size=32
        )
        # TODO: Store embeddings in vector database
        # This would integrate with the vector database adaptors
        return SkillEmbeddingResponse(
            skill_name=skill_path.name,
            total_chunks=len(chunks),
            model=request.model,
            dimensions=dimensions,
            metadata={
                "skill_path": str(skill_path),
                "chunks": len(chunks),
                "content_length": len(skill_content),
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/cache/stats", response_model=dict)
async def cache_stats():
    """Return cache statistics; 404 when caching is disabled."""
    if cache is None:
        raise HTTPException(status_code=404, detail="Cache is disabled")
    return cache.stats()
@app.post("/cache/clear", response_model=dict)
async def clear_cache(
    model: Optional[str] = Query(None, description="Model to clear (all if not specified)")
):
    """Clear cache entries, optionally restricted to one model."""
    if cache is None:
        raise HTTPException(status_code=404, detail="Cache is disabled")
    removed = cache.clear(model=model)
    return {"status": "ok", "deleted": removed, "model": model or "all"}
@app.post("/cache/clear-expired", response_model=dict)
async def clear_expired():
    """Evict all cache entries whose TTL has elapsed."""
    if cache is None:
        raise HTTPException(status_code=404, detail="Cache is disabled")
    removed = cache.clear_expired()
    return {"status": "ok", "deleted": removed}
else:
print("Error: FastAPI not available. Install with: pip install fastapi uvicorn")
sys.exit(1)
def main():
    """Entry point: read server config from the environment and run uvicorn."""
    if not FASTAPI_AVAILABLE:
        print("Error: FastAPI not available. Install with: pip install fastapi uvicorn")
        sys.exit(1)
    # Server configuration comes entirely from environment variables.
    host = os.getenv("EMBEDDING_HOST", "0.0.0.0")
    port = int(os.getenv("EMBEDDING_PORT", "8000"))
    auto_reload = os.getenv("EMBEDDING_RELOAD", "false").lower() == "true"
    banner = [
        f"🚀 Starting Embedding API server on {host}:{port}",
        f"📚 API documentation: http://{host}:{port}/docs",
        f"🔍 Cache enabled: {cache_enabled}",
    ]
    if cache_enabled:
        banner.append(f"💾 Cache database: {cache_db}")
    for line in banner:
        print(line)
    uvicorn.run(
        "skill_seekers.embedding.server:app",
        host=host,
        port=port,
        reload=auto_reload,
    )
if __name__ == "__main__":
main()