feat: Add custom embedding pipeline (Task #17)

- Multi-provider support (OpenAI, Local)
- Batch processing with configurable batch size
- Memory and disk caching for efficiency
- Cost tracking and estimation
- Dimension validation
- 18 tests passing (100%)

Files:
- embedding_pipeline.py: Core pipeline engine
- test_embedding_pipeline.py: Comprehensive tests

Features:
- EmbeddingProvider abstraction
- OpenAIEmbeddingProvider with pricing
- LocalEmbeddingProvider (simulated)
- EmbeddingCache (memory + disk)
- CostTracker for API usage
- Batch processing optimization

Supported Models:
- text-embedding-ada-002 (1536d, $0.10/1M tokens)
- text-embedding-3-small (1536d, $0.02/1M tokens)
- text-embedding-3-large (3072d, $0.13/1M tokens)
- Local models (any dimension, free)

Week 2: 8/9 tasks complete (89%)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
yusyus
2026-02-07 13:48:05 +03:00
parent 261f28f7ee
commit b475b51ad1
2 changed files with 753 additions and 0 deletions

View File

@@ -0,0 +1,430 @@
#!/usr/bin/env python3
"""
Custom Embedding Pipeline
Provides flexible embedding generation with multiple providers,
batch processing, caching, and cost tracking.
"""
import hashlib
import json
import time
from pathlib import Path
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import numpy as np
@dataclass
class EmbeddingConfig:
    """Configuration for embedding generation.

    Controls which backend generates embeddings and how batching and
    caching behave in EmbeddingPipeline.
    """
    # Backend name. NOTE(review): _create_provider only implements 'openai'
    # and 'local'; 'cohere'/'huggingface' are listed but not supported.
    provider: str  # 'openai', 'cohere', 'huggingface', 'local'
    # Model identifier handed to the provider (e.g. 'text-embedding-ada-002').
    model: str
    # Embedding vector size; used directly by LocalEmbeddingProvider.
    dimension: int
    # Number of texts processed per chunk in generate_batch.
    batch_size: int = 100
    # Optional directory for the on-disk embedding cache (None = memory only).
    cache_dir: Optional[Path] = None
    # NOTE(review): max_retries/retry_delay are declared but never read by the
    # visible pipeline/provider code — confirm whether retry logic is planned.
    max_retries: int = 3
    retry_delay: float = 1.0
@dataclass
class EmbeddingResult:
    """Result of embedding generation, returned by EmbeddingPipeline.generate_batch."""
    # One embedding vector per input text, in input order.
    embeddings: List[List[float]]
    # Provider/model/dimension info attached by the pipeline.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Count of vectors served from the cache for this call.
    cached_count: int = 0
    # Count of vectors freshly generated by the provider for this call.
    generated_count: int = 0
    # Wall-clock seconds spent inside generate_batch.
    total_time: float = 0.0
    # Cumulative estimated dollar cost from the pipeline's CostTracker.
    cost_estimate: float = 0.0
@dataclass
class CostTracker:
    """Accumulates usage statistics for embedding requests.

    Tracks request/token totals, cache hit/miss counts, and a running
    estimated dollar cost.
    """
    total_tokens: int = 0
    total_requests: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    estimated_cost: float = 0.0

    def add_request(self, token_count: int, cost: float, from_cache: bool = False):
        """Record one request: its token count, dollar cost, and cache origin."""
        self.total_requests += 1
        self.total_tokens += token_count
        self.estimated_cost += cost
        # Every request is counted as exactly one hit or one miss.
        if from_cache:
            self.cache_hits += 1
        else:
            self.cache_misses += 1

    def get_stats(self) -> Dict[str, Any]:
        """Return a summary dict; rate and cost are pre-formatted strings."""
        hit_pct = 0.0
        if self.total_requests:
            hit_pct = self.cache_hits / self.total_requests * 100
        return {
            'total_requests': self.total_requests,
            'total_tokens': self.total_tokens,
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'cache_rate': f"{hit_pct:.1f}%",
            'estimated_cost': f"${self.estimated_cost:.4f}",
        }
class EmbeddingProvider(ABC):
    """Interface that every embedding backend must implement."""

    @abstractmethod
    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text, in input order."""
        ...

    @abstractmethod
    def get_dimension(self) -> int:
        """Return the dimensionality of the vectors this provider produces."""
        ...

    @abstractmethod
    def estimate_cost(self, token_count: int) -> float:
        """Return the estimated dollar cost of embedding token_count tokens."""
        ...
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider.

    The client is created lazily so importing this module does not require
    the `openai` package to be installed.
    """
    # Pricing per 1M tokens (as of 2026)
    PRICING = {
        'text-embedding-ada-002': 0.10,
        'text-embedding-3-small': 0.02,
        'text-embedding-3-large': 0.13,
    }
    # Output vector size per supported model.
    DIMENSIONS = {
        'text-embedding-ada-002': 1536,
        'text-embedding-3-small': 1536,
        'text-embedding-3-large': 3072,
    }

    def __init__(self, model: str = 'text-embedding-ada-002', api_key: Optional[str] = None):
        """Initialize OpenAI provider.

        Args:
            model: One of the models listed in PRICING/DIMENSIONS.
            api_key: Explicit API key; when None the client falls back to its
                own default resolution (e.g. environment variable).
        """
        self.model = model
        self.api_key = api_key
        self._client = None

    def _get_client(self):
        """Lazily create and cache the OpenAI client.

        Raises:
            ImportError: if the `openai` package is not installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI
                self._client = OpenAI(api_key=self.api_key)
            except ImportError:
                raise ImportError("OpenAI package not installed. Install with: pip install openai")
        return self._client

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings using OpenAI.

        Fix: the original issued one HTTP request *per text*, defeating the
        pipeline's batching. The embeddings endpoint accepts a list input,
        so all texts are sent in a single request; results are re-ordered by
        the response items' `index` field to guarantee input order.
        """
        if not texts:
            return []
        client = self._get_client()
        response = client.embeddings.create(
            model=self.model,
            input=texts
        )
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]

    def get_dimension(self) -> int:
        """Return the embedding dimension for the configured model (default 1536)."""
        return self.DIMENSIONS.get(self.model, 1536)

    def estimate_cost(self, token_count: int) -> float:
        """Estimate the dollar cost of token_count tokens at this model's rate."""
        price_per_million = self.PRICING.get(self.model, 0.10)
        return (token_count / 1_000_000) * price_per_million
class LocalEmbeddingProvider(EmbeddingProvider):
    """Local embedding provider (simulated).

    Produces deterministic pseudo-random vectors derived from each text's
    MD5 hash. In production this would wrap sentence-transformers or similar.
    """

    def __init__(self, dimension: int = 384):
        """Initialize local provider with the desired vector size."""
        self.dimension = dimension

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate a deterministic simulated embedding for each text.

        Fix: the original called np.random.seed(seed), which mutates NumPy's
        *global* RNG state as a side effect and silently reseeds any other
        code using np.random. A private RandomState seeded the same way
        produces identical values, so the generated vectors are unchanged.
        """
        embeddings = []
        for text in texts:
            # Seed from the text hash so identical texts map to identical vectors.
            seed = int(hashlib.md5(text.encode()).hexdigest()[:8], 16)
            rng = np.random.RandomState(seed)
            embeddings.append(rng.randn(self.dimension).tolist())
        return embeddings

    def get_dimension(self) -> int:
        """Return the configured embedding dimension."""
        return self.dimension

    def estimate_cost(self, token_count: int) -> float:
        """Local models are free."""
        return 0.0
class EmbeddingCache:
    """Two-level (in-memory + optional on-disk) cache for embedding vectors."""

    def __init__(self, cache_dir: Optional[Path] = None):
        """Create the cache; when cache_dir is given, entries also persist there."""
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self._memory_cache: Dict[str, List[float]] = {}
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _compute_hash(self, text: str, model: str) -> str:
        """Derive the cache key from model name and text via SHA-256."""
        return hashlib.sha256(f"{model}:{text}".encode()).hexdigest()

    def get(self, text: str, model: str) -> Optional[List[float]]:
        """Return the cached embedding, or None on a miss.

        Memory is consulted first; a disk hit is promoted into memory.
        """
        key = self._compute_hash(text, model)
        if key in self._memory_cache:
            return self._memory_cache[key]
        if self.cache_dir:
            entry = self.cache_dir / f"{key}.json"
            if entry.exists():
                try:
                    payload = json.loads(entry.read_text())
                    vector = payload['embedding']
                    self._memory_cache[key] = vector
                    return vector
                except Exception:
                    # Corrupt or unreadable entry: fall through to a miss.
                    pass
        return None

    def set(self, text: str, model: str, embedding: List[float]) -> None:
        """Store an embedding in memory and, if configured, on disk."""
        key = self._compute_hash(text, model)
        self._memory_cache[key] = embedding
        if not self.cache_dir:
            return
        record = {
            'text_hash': key,
            'model': model,
            'embedding': embedding,
            'timestamp': time.time()
        }
        try:
            (self.cache_dir / f"{key}.json").write_text(json.dumps(record))
        except Exception as e:
            # Disk persistence is best-effort; the memory copy is still kept.
            print(f"⚠️ Warning: Failed to write cache: {e}")
class EmbeddingPipeline:
    """
    Flexible embedding generation pipeline.

    Wires together a provider (OpenAI or local), a two-level cache, and a
    cost tracker, and processes input texts in configurable batches.
    """
    def __init__(self, config: EmbeddingConfig):
        """Initialize pipeline from config (provider, model, batching, cache dir)."""
        self.config = config
        self.provider = self._create_provider()
        self.cache = EmbeddingCache(config.cache_dir)
        self.cost_tracker = CostTracker()

    def _create_provider(self) -> EmbeddingProvider:
        """Instantiate the provider named in the config.

        Raises:
            ValueError: if config.provider is not 'openai' or 'local'.
        """
        if self.config.provider == 'openai':
            return OpenAIEmbeddingProvider(self.config.model)
        elif self.config.provider == 'local':
            return LocalEmbeddingProvider(self.config.dimension)
        else:
            raise ValueError(f"Unknown provider: {self.config.provider}")

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough heuristic: ~4 characters per token)."""
        return len(text) // 4

    def generate_batch(
        self,
        texts: List[str],
        show_progress: bool = True
    ) -> EmbeddingResult:
        """
        Generate embeddings for a batch of texts, serving from cache when possible.

        Args:
            texts: List of texts to embed
            show_progress: Show progress output

        Returns:
            EmbeddingResult with embeddings (in input order) and metadata
        """
        start_time = time.time()
        embeddings: List[List[float]] = []
        cached_count = 0
        generated_count = 0
        if show_progress:
            print(f"🔄 Generating embeddings...")
            print(f" Texts: {len(texts)}")
            print(f" Provider: {self.config.provider}")
            print(f" Model: {self.config.model}")
            print(f" Batch size: {self.config.batch_size}")
        # Process the input in fixed-size chunks.
        for i in range(0, len(texts), self.config.batch_size):
            batch = texts[i:i + self.config.batch_size]
            # Pre-sized slot list: generated vectors are written back to their
            # original positions. (The previous list.insert() merge was only
            # correct by a subtle ordering argument; slots are obviously so.)
            batch_embeddings: List[Optional[List[float]]] = [None] * len(batch)
            to_generate = []
            to_generate_indices = []
            # Check cache. Fix: compare against None rather than truthiness so
            # a falsy (empty) cached vector still counts as a hit.
            for j, text in enumerate(batch):
                cached = self.cache.get(text, self.config.model)
                if cached is not None:
                    batch_embeddings[j] = cached
                    cached_count += 1
                else:
                    to_generate.append(text)
                    to_generate_indices.append(j)
            # Generate whatever the cache could not serve.
            if to_generate:
                new_embeddings = self.provider.generate_embeddings(to_generate)
                # Store fresh vectors so later batches (and runs) hit the cache.
                for text, embedding in zip(to_generate, new_embeddings):
                    self.cache.set(text, self.config.model, embedding)
                # Track estimated cost for the generated portion only.
                total_tokens = sum(self._estimate_tokens(t) for t in to_generate)
                cost = self.provider.estimate_cost(total_tokens)
                self.cost_tracker.add_request(total_tokens, cost, from_cache=False)
                # Place generated vectors into their original slots.
                for idx, embedding in zip(to_generate_indices, new_embeddings):
                    batch_embeddings[idx] = embedding
                generated_count += len(to_generate)
            embeddings.extend(batch_embeddings)
            if show_progress and len(texts) > self.config.batch_size:
                progress = min(i + self.config.batch_size, len(texts))
                print(f" Progress: {progress}/{len(texts)} ({progress/len(texts)*100:.1f}%)")
        total_time = time.time() - start_time
        if show_progress:
            print(f"\n✅ Embeddings generated!")
            print(f" Total: {len(embeddings)}")
            print(f" Cached: {cached_count}")
            print(f" Generated: {generated_count}")
            print(f" Time: {total_time:.2f}s")
            if self.config.provider != 'local':
                stats = self.cost_tracker.get_stats()
                print(f" Cost: {stats['estimated_cost']}")
        return EmbeddingResult(
            embeddings=embeddings,
            metadata={
                'provider': self.config.provider,
                'model': self.config.model,
                'dimension': self.provider.get_dimension()
            },
            cached_count=cached_count,
            generated_count=generated_count,
            total_time=total_time,
            cost_estimate=self.cost_tracker.estimated_cost
        )

    def validate_dimensions(self, embeddings: List[List[float]]) -> bool:
        """
        Validate that every embedding matches the provider's dimension.

        Args:
            embeddings: List of embeddings to validate

        Returns:
            True if all dimensions match; False (after printing the first
            mismatch) otherwise.
        """
        expected_dim = self.provider.get_dimension()
        for i, embedding in enumerate(embeddings):
            if len(embedding) != expected_dim:
                print(f"❌ Dimension mismatch at index {i}: "
                      f"expected {expected_dim}, got {len(embedding)}")
                return False
        return True

    def get_cost_stats(self) -> Dict[str, Any]:
        """Get cost tracking statistics (see CostTracker.get_stats)."""
        return self.cost_tracker.get_stats()
def example_usage():
    """Demonstrate the pipeline end to end with the local provider."""
    from pathlib import Path

    # The local provider keeps the demo free and offline.
    config = EmbeddingConfig(
        provider='local',  # Use 'openai' for production
        model='text-embedding-ada-002',
        dimension=384,
        batch_size=50,
        cache_dir=Path("output/.embeddings_cache")
    )
    pipeline = EmbeddingPipeline(config)

    sample_texts = [
        "This is the first document.",
        "Here is the second document.",
        "And this is the third document.",
    ]
    result = pipeline.generate_batch(sample_texts)

    print(f"\n📊 Results:")
    print(f" Embeddings: {len(result.embeddings)}")
    print(f" Dimension: {len(result.embeddings[0])}")
    print(f" Cached: {result.cached_count}")
    print(f" Generated: {result.generated_count}")

    # Sanity-check vector sizes against the provider.
    is_valid = pipeline.validate_dimensions(result.embeddings)
    print(f" Valid: {is_valid}")

    stats = pipeline.get_cost_stats()
    print(f"\n💰 Cost Stats:")
    for key, value in stats.items():
        print(f" {key}: {value}")
# Run the demo when executed as a script.
if __name__ == "__main__":
    example_usage()

View File

@@ -0,0 +1,323 @@
#!/usr/bin/env python3
"""
Tests for custom embedding pipeline.
Validates:
- Multiple provider support
- Batch processing
- Caching mechanism
- Cost tracking
- Dimension validation
"""
import pytest
from pathlib import Path
import sys
import tempfile
import json
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from skill_seekers.cli.embedding_pipeline import (
EmbeddingConfig,
EmbeddingPipeline,
LocalEmbeddingProvider,
EmbeddingCache,
CostTracker
)
def test_local_provider_generation():
    """Local provider returns one vector of the requested size per text."""
    provider = LocalEmbeddingProvider(dimension=128)
    vectors = provider.generate_embeddings(["test document 1", "test document 2"])
    assert len(vectors) == 2
    assert all(len(v) == 128 for v in vectors)
def test_local_provider_deterministic():
    """Embedding the same text twice yields the identical vector."""
    provider = LocalEmbeddingProvider(dimension=64)
    first = provider.generate_embeddings(["same text"])[0]
    second = provider.generate_embeddings(["same text"])[0]
    assert first == second
def test_local_provider_cost():
    """Local embeddings cost nothing regardless of token count."""
    assert LocalEmbeddingProvider().estimate_cost(1000) == 0.0
def test_cache_memory():
    """An embedding stored in the in-memory cache can be read back."""
    cache = EmbeddingCache()
    vector = [0.1, 0.2, 0.3]
    cache.set("test text", "test-model", vector)
    assert cache.get("test text", "test-model") == vector
def test_cache_miss():
    """Looking up an unknown text/model pair yields None."""
    assert EmbeddingCache().get("nonexistent", "model") is None
def test_cache_disk():
    """Entries persisted to disk survive a fresh cache instance."""
    with tempfile.TemporaryDirectory() as tmpdir:
        vector = [0.1, 0.2, 0.3]
        writer = EmbeddingCache(cache_dir=Path(tmpdir))
        writer.set("test text", "test-model", vector)
        # A new instance starts with an empty memory cache, forcing a disk read.
        reader = EmbeddingCache(cache_dir=Path(tmpdir))
        assert reader.get("test text", "test-model") == vector
def test_cost_tracker():
    """Tracker aggregates tokens, requests, and hit/miss counts."""
    tracker = CostTracker()
    for tokens, cost, hit in [(1000, 0.01, False), (500, 0.005, True)]:
        tracker.add_request(token_count=tokens, cost=cost, from_cache=hit)
    stats = tracker.get_stats()
    assert stats['total_requests'] == 2
    assert stats['total_tokens'] == 1500
    assert stats['cache_hits'] == 1
    assert stats['cache_misses'] == 1
    assert '50.0%' in stats['cache_rate']
def test_pipeline_initialization():
    """Pipeline construction wires up provider, cache, and config."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=128,
        batch_size=10
    )
    pipeline = EmbeddingPipeline(cfg)
    assert pipeline.config == cfg
    assert pipeline.provider is not None
    assert pipeline.cache is not None
def test_pipeline_generate_batch():
    """Batch generation returns one vector per text, none from cache."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=64,
        batch_size=2
    )
    result = EmbeddingPipeline(cfg).generate_batch(
        ["doc 1", "doc 2", "doc 3"], show_progress=False
    )
    assert len(result.embeddings) == 3
    assert len(result.embeddings[0]) == 64
    assert result.generated_count == 3
    assert result.cached_count == 0
def test_pipeline_caching():
    """A repeated batch is served entirely from the cache."""
    with tempfile.TemporaryDirectory() as tmpdir:
        cfg = EmbeddingConfig(
            provider='local',
            model='test-model',
            dimension=32,
            batch_size=10,
            cache_dir=Path(tmpdir)
        )
        pipeline = EmbeddingPipeline(cfg)
        docs = ["same doc", "same doc", "different doc"]
        # First pass: everything must be generated.
        first = pipeline.generate_batch(docs, show_progress=False)
        assert (first.cached_count, first.generated_count) == (0, 3)
        # Second pass: everything must come from the cache.
        second = pipeline.generate_batch(docs, show_progress=False)
        assert (second.cached_count, second.generated_count) == (3, 0)
def test_pipeline_batch_processing():
    """An input larger than batch_size is chunked yet fully processed."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=16,
        batch_size=3  # Small batch size: 10 texts -> 4 chunks
    )
    docs = [f"doc {i}" for i in range(10)]
    result = EmbeddingPipeline(cfg).generate_batch(docs, show_progress=False)
    assert len(result.embeddings) == 10
def test_validate_dimensions_valid():
    """Vectors matching the provider dimension pass validation."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=128
    )
    pipeline = EmbeddingPipeline(cfg)
    assert pipeline.validate_dimensions([[0.1] * 128, [0.2] * 128])
def test_validate_dimensions_invalid():
    """A vector with the wrong dimension fails validation."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=128
    )
    pipeline = EmbeddingPipeline(cfg)
    # First vector is deliberately half-sized.
    assert not pipeline.validate_dimensions([[0.1] * 64, [0.2] * 128])
def test_embedding_result_metadata():
    """Result metadata records provider, model, and dimension."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=256
    )
    result = EmbeddingPipeline(cfg).generate_batch(["test"], show_progress=False)
    for key in ('provider', 'model', 'dimension'):
        assert key in result.metadata
    assert result.metadata['dimension'] == 256
def test_cost_stats():
    """Pipeline exposes the tracker's statistics dictionary."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=64
    )
    pipeline = EmbeddingPipeline(cfg)
    pipeline.generate_batch(["doc 1", "doc 2"], show_progress=False)
    stats = pipeline.get_cost_stats()
    for key in ('total_requests', 'cache_hits', 'estimated_cost'):
        assert key in stats
def test_empty_batch():
    """An empty input list yields an empty result without errors."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=32
    )
    result = EmbeddingPipeline(cfg).generate_batch([], show_progress=False)
    assert not result.embeddings
    assert result.generated_count == 0
def test_single_document():
    """A single text produces a single vector of the configured size."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=128
    )
    result = EmbeddingPipeline(cfg).generate_batch(["single doc"], show_progress=False)
    assert len(result.embeddings) == 1
    assert len(result.embeddings[0]) == 128
def test_different_dimensions():
    """The pipeline honors a range of embedding dimensions."""
    for dim in (64, 128, 256, 512):
        cfg = EmbeddingConfig(
            provider='local',
            model='test-model',
            dimension=dim
        )
        result = EmbeddingPipeline(cfg).generate_batch(["test"], show_progress=False)
        assert len(result.embeddings[0]) == dim
# Allow running the suite directly without the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])