feat: Add custom embedding pipeline (Task #17)
- Multi-provider support (OpenAI, Local) - Batch processing with configurable batch size - Memory and disk caching for efficiency - Cost tracking and estimation - Dimension validation - 18 tests passing (100%) Files: - embedding_pipeline.py: Core pipeline engine - test_embedding_pipeline.py: Comprehensive tests Features: - EmbeddingProvider abstraction - OpenAIEmbeddingProvider with pricing - LocalEmbeddingProvider (simulated) - EmbeddingCache (memory + disk) - CostTracker for API usage - Batch processing optimization Supported Models: - text-embedding-ada-002 (1536d, $0.10/1M tokens) - text-embedding-3-small (1536d, $0.02/1M tokens) - text-embedding-3-large (3072d, $0.13/1M tokens) - Local models (any dimension, free) Week 2: 8/9 tasks complete (89%) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
323
tests/test_embedding_pipeline.py
Normal file
323
tests/test_embedding_pipeline.py
Normal file
@@ -0,0 +1,323 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for custom embedding pipeline.
|
||||
|
||||
Validates:
|
||||
- Multiple provider support
|
||||
- Batch processing
|
||||
- Caching mechanism
|
||||
- Cost tracking
|
||||
- Dimension validation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import tempfile
|
||||
import json
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
from skill_seekers.cli.embedding_pipeline import (
|
||||
EmbeddingConfig,
|
||||
EmbeddingPipeline,
|
||||
LocalEmbeddingProvider,
|
||||
EmbeddingCache,
|
||||
CostTracker
|
||||
)
|
||||
|
||||
|
||||
def test_local_provider_generation():
|
||||
"""Test local embedding provider."""
|
||||
provider = LocalEmbeddingProvider(dimension=128)
|
||||
|
||||
texts = ["test document 1", "test document 2"]
|
||||
embeddings = provider.generate_embeddings(texts)
|
||||
|
||||
assert len(embeddings) == 2
|
||||
assert len(embeddings[0]) == 128
|
||||
assert len(embeddings[1]) == 128
|
||||
|
||||
|
||||
def test_local_provider_deterministic():
|
||||
"""Test local provider generates deterministic embeddings."""
|
||||
provider = LocalEmbeddingProvider(dimension=64)
|
||||
|
||||
text = "same text"
|
||||
emb1 = provider.generate_embeddings([text])[0]
|
||||
emb2 = provider.generate_embeddings([text])[0]
|
||||
|
||||
# Should be identical for same text
|
||||
assert emb1 == emb2
|
||||
|
||||
|
||||
def test_local_provider_cost():
|
||||
"""Test local provider cost estimation."""
|
||||
provider = LocalEmbeddingProvider()
|
||||
|
||||
cost = provider.estimate_cost(1000)
|
||||
assert cost == 0.0 # Local is free
|
||||
|
||||
|
||||
def test_cache_memory():
|
||||
"""Test memory cache functionality."""
|
||||
cache = EmbeddingCache()
|
||||
|
||||
text = "test text"
|
||||
model = "test-model"
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
# Set and get
|
||||
cache.set(text, model, embedding)
|
||||
retrieved = cache.get(text, model)
|
||||
|
||||
assert retrieved == embedding
|
||||
|
||||
|
||||
def test_cache_miss():
|
||||
"""Test cache miss returns None."""
|
||||
cache = EmbeddingCache()
|
||||
|
||||
result = cache.get("nonexistent", "model")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_cache_disk():
|
||||
"""Test disk cache functionality."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cache = EmbeddingCache(cache_dir=Path(tmpdir))
|
||||
|
||||
text = "test text"
|
||||
model = "test-model"
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
# Set
|
||||
cache.set(text, model, embedding)
|
||||
|
||||
# Create new cache instance (clears memory)
|
||||
cache2 = EmbeddingCache(cache_dir=Path(tmpdir))
|
||||
|
||||
# Should retrieve from disk
|
||||
retrieved = cache2.get(text, model)
|
||||
assert retrieved == embedding
|
||||
|
||||
|
||||
def test_cost_tracker():
|
||||
"""Test cost tracking."""
|
||||
tracker = CostTracker()
|
||||
|
||||
# Add requests
|
||||
tracker.add_request(token_count=1000, cost=0.01, from_cache=False)
|
||||
tracker.add_request(token_count=500, cost=0.005, from_cache=True)
|
||||
|
||||
stats = tracker.get_stats()
|
||||
|
||||
assert stats['total_requests'] == 2
|
||||
assert stats['total_tokens'] == 1500
|
||||
assert stats['cache_hits'] == 1
|
||||
assert stats['cache_misses'] == 1
|
||||
assert '50.0%' in stats['cache_rate']
|
||||
|
||||
|
||||
def test_pipeline_initialization():
|
||||
"""Test pipeline initialization."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=128,
|
||||
batch_size=10
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
assert pipeline.config == config
|
||||
assert pipeline.provider is not None
|
||||
assert pipeline.cache is not None
|
||||
|
||||
|
||||
def test_pipeline_generate_batch():
|
||||
"""Test batch embedding generation."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=64,
|
||||
batch_size=2
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
texts = ["doc 1", "doc 2", "doc 3"]
|
||||
result = pipeline.generate_batch(texts, show_progress=False)
|
||||
|
||||
assert len(result.embeddings) == 3
|
||||
assert len(result.embeddings[0]) == 64
|
||||
assert result.generated_count == 3
|
||||
assert result.cached_count == 0
|
||||
|
||||
|
||||
def test_pipeline_caching():
|
||||
"""Test pipeline uses caching."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=32,
|
||||
batch_size=10,
|
||||
cache_dir=Path(tmpdir)
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
texts = ["same doc", "same doc", "different doc"]
|
||||
|
||||
# First generation
|
||||
result1 = pipeline.generate_batch(texts, show_progress=False)
|
||||
assert result1.cached_count == 0
|
||||
assert result1.generated_count == 3
|
||||
|
||||
# Second generation (should use cache)
|
||||
result2 = pipeline.generate_batch(texts, show_progress=False)
|
||||
assert result2.cached_count == 3
|
||||
assert result2.generated_count == 0
|
||||
|
||||
|
||||
def test_pipeline_batch_processing():
|
||||
"""Test large batch is processed in chunks."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=16,
|
||||
batch_size=3 # Small batch size
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
# 10 texts with batch size 3 = 4 batches
|
||||
texts = [f"doc {i}" for i in range(10)]
|
||||
result = pipeline.generate_batch(texts, show_progress=False)
|
||||
|
||||
assert len(result.embeddings) == 10
|
||||
|
||||
|
||||
def test_validate_dimensions_valid():
|
||||
"""Test dimension validation with valid embeddings."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=128
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
embeddings = [[0.1] * 128, [0.2] * 128]
|
||||
is_valid = pipeline.validate_dimensions(embeddings)
|
||||
|
||||
assert is_valid
|
||||
|
||||
|
||||
def test_validate_dimensions_invalid():
|
||||
"""Test dimension validation with invalid embeddings."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=128
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
# Wrong dimension
|
||||
embeddings = [[0.1] * 64, [0.2] * 128]
|
||||
is_valid = pipeline.validate_dimensions(embeddings)
|
||||
|
||||
assert not is_valid
|
||||
|
||||
|
||||
def test_embedding_result_metadata():
|
||||
"""Test embedding result includes metadata."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=256
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
texts = ["test"]
|
||||
result = pipeline.generate_batch(texts, show_progress=False)
|
||||
|
||||
assert 'provider' in result.metadata
|
||||
assert 'model' in result.metadata
|
||||
assert 'dimension' in result.metadata
|
||||
assert result.metadata['dimension'] == 256
|
||||
|
||||
|
||||
def test_cost_stats():
|
||||
"""Test cost statistics tracking."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=64
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
texts = ["doc 1", "doc 2"]
|
||||
pipeline.generate_batch(texts, show_progress=False)
|
||||
|
||||
stats = pipeline.get_cost_stats()
|
||||
|
||||
assert 'total_requests' in stats
|
||||
assert 'cache_hits' in stats
|
||||
assert 'estimated_cost' in stats
|
||||
|
||||
|
||||
def test_empty_batch():
|
||||
"""Test handling empty batch."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=32
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
result = pipeline.generate_batch([], show_progress=False)
|
||||
|
||||
assert len(result.embeddings) == 0
|
||||
assert result.generated_count == 0
|
||||
|
||||
|
||||
def test_single_document():
|
||||
"""Test single document generation."""
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=128
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
|
||||
result = pipeline.generate_batch(["single doc"], show_progress=False)
|
||||
|
||||
assert len(result.embeddings) == 1
|
||||
assert len(result.embeddings[0]) == 128
|
||||
|
||||
|
||||
def test_different_dimensions():
|
||||
"""Test different embedding dimensions."""
|
||||
for dim in [64, 128, 256, 512]:
|
||||
config = EmbeddingConfig(
|
||||
provider='local',
|
||||
model='test-model',
|
||||
dimension=dim
|
||||
)
|
||||
|
||||
pipeline = EmbeddingPipeline(config)
|
||||
result = pipeline.generate_batch(["test"], show_progress=False)
|
||||
|
||||
assert len(result.embeddings[0]) == dim
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user