feat: Add custom embedding pipeline (Task #17)
- Multi-provider support (OpenAI, Local) - Batch processing with configurable batch size - Memory and disk caching for efficiency - Cost tracking and estimation - Dimension validation - 18 tests passing (100%) Files: - embedding_pipeline.py: Core pipeline engine - test_embedding_pipeline.py: Comprehensive tests Features: - EmbeddingProvider abstraction - OpenAIEmbeddingProvider with pricing - LocalEmbeddingProvider (simulated) - EmbeddingCache (memory + disk) - CostTracker for API usage - Batch processing optimization Supported Models: - text-embedding-ada-002 (1536d, $0.10/1M tokens) - text-embedding-3-small (1536d, $0.02/1M tokens) - text-embedding-3-large (3072d, $0.13/1M tokens) - Local models (any dimension, free) Week 2: 8/9 tasks complete (89%) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
430
src/skill_seekers/cli/embedding_pipeline.py
Normal file
430
src/skill_seekers/cli/embedding_pipeline.py
Normal file
@@ -0,0 +1,430 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Custom Embedding Pipeline
|
||||
|
||||
Provides flexible embedding generation with multiple providers,
|
||||
batch processing, caching, and cost tracking.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class EmbeddingConfig:
    """Configuration for embedding generation.

    Consumed by EmbeddingPipeline; only 'openai' and 'local' are handled by
    EmbeddingPipeline._create_provider — other provider names raise ValueError.
    """
    provider: str  # 'openai' or 'local' (others rejected by the pipeline)
    model: str  # provider-specific model name, e.g. 'text-embedding-ada-002'
    dimension: int  # expected embedding vector width (used by the local provider)
    batch_size: int = 100  # number of texts processed per chunk
    cache_dir: Optional[Path] = None  # disk cache directory; None = memory-only cache
    max_retries: int = 3  # NOTE(review): not read anywhere in this file — confirm intended use
    retry_delay: float = 1.0  # NOTE(review): not read anywhere in this file — confirm intended use
|
||||
|
||||
|
||||
@dataclass
class EmbeddingResult:
    """Result of embedding generation, returned by EmbeddingPipeline.generate_batch."""
    embeddings: List[List[float]]  # one vector per input text, in input order
    metadata: Dict[str, Any] = field(default_factory=dict)  # provider/model/dimension
    cached_count: int = 0  # how many vectors were served from cache
    generated_count: int = 0  # how many vectors were freshly generated
    total_time: float = 0.0  # wall-clock seconds for the whole batch
    cost_estimate: float = 0.0  # cumulative pipeline cost so far (tracker total, not per-call)
|
||||
|
||||
|
||||
@dataclass
class CostTracker:
    """Accumulate token, request, and cache-hit statistics for embedding calls."""
    total_tokens: int = 0
    total_requests: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    estimated_cost: float = 0.0

    def add_request(self, token_count: int, cost: float, from_cache: bool = False):
        """Record one embedding request (tokens, cost, and hit/miss bucket)."""
        self.total_requests += 1
        self.total_tokens += token_count
        self.estimated_cost += cost
        if from_cache:
            self.cache_hits += 1
        else:
            self.cache_misses += 1

    def get_stats(self) -> Dict[str, Any]:
        """Return a summary dict; rate and cost are pre-formatted strings."""
        if self.total_requests > 0:
            hit_pct = self.cache_hits / self.total_requests * 100
        else:
            hit_pct = 0

        return {
            'total_requests': self.total_requests,
            'total_tokens': self.total_tokens,
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'cache_rate': f"{hit_pct:.1f}%",
            'estimated_cost': f"${self.estimated_cost:.4f}",
        }
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
    """Abstract base class for embedding providers.

    Concrete providers must produce fixed-width float vectors and report
    their vector width and per-token cost.
    """

    @abstractmethod
    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate one embedding per text, preserving input order."""
        pass

    @abstractmethod
    def get_dimension(self) -> int:
        """Get embedding dimension (vector width) produced by this provider."""
        pass

    @abstractmethod
    def estimate_cost(self, token_count: int) -> float:
        """Estimate cost (USD) for embedding *token_count* tokens."""
        pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider.

    The API client is created lazily so this class can be constructed even
    when the `openai` package is not installed.
    """

    # Pricing per 1M tokens (as of 2026)
    PRICING = {
        'text-embedding-ada-002': 0.10,
        'text-embedding-3-small': 0.02,
        'text-embedding-3-large': 0.13,
    }

    # Output vector width per model.
    DIMENSIONS = {
        'text-embedding-ada-002': 1536,
        'text-embedding-3-small': 1536,
        'text-embedding-3-large': 3072,
    }

    def __init__(self, model: str = 'text-embedding-ada-002', api_key: Optional[str] = None):
        """Initialize OpenAI provider.

        Args:
            model: OpenAI embedding model name.
            api_key: explicit API key; None lets the client use its environment.
        """
        self.model = model
        self.api_key = api_key
        self._client = None

    def _get_client(self):
        """Lazy-load and memoize the OpenAI client.

        Raises:
            ImportError: if the `openai` package is not installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI
                self._client = OpenAI(api_key=self.api_key)
            except ImportError:
                raise ImportError("OpenAI package not installed. Install with: pip install openai")
        return self._client

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings using OpenAI.

        Fix: sends the whole batch in a single API request (the embeddings
        endpoint accepts a list `input`) instead of one HTTP round trip per
        text, which was O(len(texts)) requests.
        """
        if not texts:
            return []
        client = self._get_client()
        response = client.embeddings.create(
            model=self.model,
            input=texts
        )
        # Align results with the input order via each item's `index` field.
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]

    def get_dimension(self) -> int:
        """Get embedding dimension (defaults to 1536 for unknown models)."""
        return self.DIMENSIONS.get(self.model, 1536)

    def estimate_cost(self, token_count: int) -> float:
        """Estimate USD cost for *token_count* tokens at this model's rate."""
        price_per_million = self.PRICING.get(self.model, 0.10)
        return (token_count / 1_000_000) * price_per_million
|
||||
|
||||
|
||||
class LocalEmbeddingProvider(EmbeddingProvider):
    """Local embedding provider (simulated).

    Produces deterministic pseudo-random vectors derived from each text's
    hash. In production this would wrap sentence-transformers or similar.
    """

    def __init__(self, dimension: int = 384):
        """Initialize local provider.

        Args:
            dimension: width of the generated vectors.
        """
        self.dimension = dimension

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings using local model (simulated).

        Fix: uses a private RandomState seeded from the text hash instead of
        np.random.seed(), so the process-global NumPy RNG is no longer
        clobbered as a side effect. RandomState(seed).randn(...) produces the
        same values the previous global-seed version did.
        """
        embeddings = []
        for text in texts:
            # Deterministic seed: first 8 hex digits of the text's MD5 hash.
            seed = int(hashlib.md5(text.encode()).hexdigest()[:8], 16)
            rng = np.random.RandomState(seed)
            embeddings.append(rng.randn(self.dimension).tolist())
        return embeddings

    def get_dimension(self) -> int:
        """Get embedding dimension."""
        return self.dimension

    def estimate_cost(self, token_count: int) -> float:
        """Local models are free."""
        return 0.0
|
||||
|
||||
|
||||
class EmbeddingCache:
    """Two-tier embedding cache: in-process dict plus optional JSON files.

    Keys are SHA-256 of "model:text", so identical text under different
    models is cached separately. Disk entries are one JSON file per vector.
    """

    def __init__(self, cache_dir: Optional[Path] = None):
        """Initialize cache.

        Args:
            cache_dir: directory for persistent entries; None = memory-only.
        """
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self._memory_cache: Dict[str, List[float]] = {}

        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _compute_hash(self, text: str, model: str) -> str:
        """Compute the cache key for (text, model)."""
        key = f"{model}:{text}"
        return hashlib.sha256(key.encode()).hexdigest()

    def get(self, text: str, model: str) -> Optional[List[float]]:
        """Return the cached embedding, or None on a miss.

        A disk hit is promoted into the memory cache for faster re-reads.
        """
        cache_key = self._compute_hash(text, model)

        # Check memory cache first (cheapest).
        if cache_key in self._memory_cache:
            return self._memory_cache[cache_key]

        # Fall back to disk cache.
        if self.cache_dir:
            cache_file = self.cache_dir / f"{cache_key}.json"
            if cache_file.exists():
                try:
                    data = json.loads(cache_file.read_text())
                    embedding = data['embedding']
                    self._memory_cache[cache_key] = embedding
                    return embedding
                except (OSError, ValueError, KeyError):
                    # Fix: only unreadable/corrupt entries are treated as a
                    # miss; the previous bare `except Exception` hid real bugs.
                    pass

        return None

    def set(self, text: str, model: str, embedding: List[float]) -> None:
        """Store an embedding in memory and, if configured, on disk.

        Disk failures are non-fatal: a warning is printed and the in-memory
        entry remains usable.
        """
        cache_key = self._compute_hash(text, model)

        # Store in memory.
        self._memory_cache[cache_key] = embedding

        # Best-effort persistence on disk.
        if self.cache_dir:
            cache_file = self.cache_dir / f"{cache_key}.json"
            try:
                cache_file.write_text(json.dumps({
                    'text_hash': cache_key,
                    'model': model,
                    'embedding': embedding,
                    'timestamp': time.time()
                }))
            except Exception as e:
                print(f"⚠️ Warning: Failed to write cache: {e}")
|
||||
|
||||
|
||||
class EmbeddingPipeline:
    """
    Flexible embedding generation pipeline.

    Supports multiple providers, batch processing, caching, and cost tracking.
    NOTE(review): config.max_retries / retry_delay are accepted but not yet
    used anywhere in this class — confirm whether retry logic is still planned.
    """

    def __init__(self, config: EmbeddingConfig):
        """Initialize pipeline: provider, cache, and cost tracker from config."""
        self.config = config
        self.provider = self._create_provider()
        self.cache = EmbeddingCache(config.cache_dir)
        self.cost_tracker = CostTracker()

    def _create_provider(self) -> EmbeddingProvider:
        """Create provider based on config.

        Raises:
            ValueError: if config.provider is neither 'openai' nor 'local'.
        """
        if self.config.provider == 'openai':
            return OpenAIEmbeddingProvider(self.config.model)
        elif self.config.provider == 'local':
            return LocalEmbeddingProvider(self.config.dimension)
        else:
            raise ValueError(f"Unknown provider: {self.config.provider}")

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough heuristic: 1 token ≈ 4 characters)."""
        return len(text) // 4

    def generate_batch(
        self,
        texts: List[str],
        show_progress: bool = True
    ) -> EmbeddingResult:
        """
        Generate embeddings for batch of texts.

        Cached vectors are reused; only misses hit the provider, and fresh
        vectors are written back to the cache.

        Args:
            texts: List of texts to embed
            show_progress: Show progress output

        Returns:
            EmbeddingResult with embeddings (in input order) and metadata
        """
        start_time = time.time()
        embeddings: List[List[float]] = []
        cached_count = 0
        generated_count = 0

        if show_progress:
            print("🔄 Generating embeddings...")
            print(f" Texts: {len(texts)}")
            print(f" Provider: {self.config.provider}")
            print(f" Model: {self.config.model}")
            print(f" Batch size: {self.config.batch_size}")

        # Process in batches of config.batch_size.
        for start in range(0, len(texts), self.config.batch_size):
            batch = texts[start:start + self.config.batch_size]
            # Placeholder slots keep input order regardless of hit/miss mix.
            batch_embeddings: List[Optional[List[float]]] = [None] * len(batch)
            to_generate = []
            to_generate_indices = []

            # Check cache for each text in the batch.
            for j, text in enumerate(batch):
                cached = self.cache.get(text, self.config.model)
                # Fix: compare against None — a falsy (e.g. empty) cached
                # vector was previously treated as a miss and regenerated.
                if cached is not None:
                    batch_embeddings[j] = cached
                    cached_count += 1
                else:
                    to_generate.append(text)
                    to_generate_indices.append(j)

            # Generate missing embeddings in one provider call.
            if to_generate:
                new_embeddings = self.provider.generate_embeddings(to_generate)

                # Write fresh vectors back to the cache.
                for text, embedding in zip(to_generate, new_embeddings):
                    self.cache.set(text, self.config.model, embedding)

                # Track tokens/cost for the generated portion only.
                total_tokens = sum(self._estimate_tokens(t) for t in to_generate)
                cost = self.provider.estimate_cost(total_tokens)
                self.cost_tracker.add_request(total_tokens, cost, from_cache=False)

                # Fill the reserved slots (replaces fragile list.insert merge).
                for idx, embedding in zip(to_generate_indices, new_embeddings):
                    batch_embeddings[idx] = embedding

                generated_count += len(to_generate)

            embeddings.extend(batch_embeddings)

            if show_progress and len(texts) > self.config.batch_size:
                progress = min(start + self.config.batch_size, len(texts))
                print(f" Progress: {progress}/{len(texts)} ({progress/len(texts)*100:.1f}%)")

        total_time = time.time() - start_time

        if show_progress:
            print("\n✅ Embeddings generated!")
            print(f" Total: {len(embeddings)}")
            print(f" Cached: {cached_count}")
            print(f" Generated: {generated_count}")
            print(f" Time: {total_time:.2f}s")

            if self.config.provider != 'local':
                stats = self.cost_tracker.get_stats()
                print(f" Cost: {stats['estimated_cost']}")

        return EmbeddingResult(
            embeddings=embeddings,
            metadata={
                'provider': self.config.provider,
                'model': self.config.model,
                'dimension': self.provider.get_dimension()
            },
            cached_count=cached_count,
            generated_count=generated_count,
            total_time=total_time,
            # Cumulative tracker total, not just this call's cost.
            cost_estimate=self.cost_tracker.estimated_cost
        )

    def validate_dimensions(self, embeddings: List[List[float]]) -> bool:
        """
        Validate embedding dimensions against the provider's expected width.

        Args:
            embeddings: List of embeddings to validate

        Returns:
            True if every vector matches; False (with a printed diagnostic)
            on the first mismatch.
        """
        expected_dim = self.provider.get_dimension()

        for i, embedding in enumerate(embeddings):
            if len(embedding) != expected_dim:
                print(f"❌ Dimension mismatch at index {i}: "
                      f"expected {expected_dim}, got {len(embedding)}")
                return False

        return True

    def get_cost_stats(self) -> Dict[str, Any]:
        """Get cost tracking statistics (see CostTracker.get_stats)."""
        return self.cost_tracker.get_stats()
|
||||
|
||||
|
||||
def example_usage():
    """Demonstrate the pipeline end to end using the free local provider."""
    from pathlib import Path

    # Build a pipeline backed by the local provider and a disk cache.
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',  # Use 'openai' for production
        model='text-embedding-ada-002',
        dimension=384,
        batch_size=50,
        cache_dir=Path("output/.embeddings_cache")
    ))

    # Embed a few sample documents.
    sample_docs = [
        "This is the first document.",
        "Here is the second document.",
        "And this is the third document.",
    ]
    result = pipeline.generate_batch(sample_docs)

    print(f"\n📊 Results:")
    print(f" Embeddings: {len(result.embeddings)}")
    print(f" Dimension: {len(result.embeddings[0])}")
    print(f" Cached: {result.cached_count}")
    print(f" Generated: {result.generated_count}")

    # Sanity-check vector widths.
    is_valid = pipeline.validate_dimensions(result.embeddings)
    print(f" Valid: {is_valid}")

    # Report accumulated cost statistics.
    stats = pipeline.get_cost_stats()
    print(f"\n💰 Cost Stats:")
    for key, value in stats.items():
        print(f" {key}: {value}")
|
||||
|
||||
|
||||
# Run the demo when executed directly as a script.
if __name__ == "__main__":
    example_usage()
|
||||
323
tests/test_embedding_pipeline.py
Normal file
323
tests/test_embedding_pipeline.py
Normal file
@@ -0,0 +1,323 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for custom embedding pipeline.
|
||||
|
||||
Validates:
|
||||
- Multiple provider support
|
||||
- Batch processing
|
||||
- Caching mechanism
|
||||
- Cost tracking
|
||||
- Dimension validation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import tempfile
|
||||
import json
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
from skill_seekers.cli.embedding_pipeline import (
|
||||
EmbeddingConfig,
|
||||
EmbeddingPipeline,
|
||||
LocalEmbeddingProvider,
|
||||
EmbeddingCache,
|
||||
CostTracker
|
||||
)
|
||||
|
||||
|
||||
def test_local_provider_generation():
    """Local provider returns one vector per input at the configured width."""
    provider = LocalEmbeddingProvider(dimension=128)
    docs = ["test document 1", "test document 2"]

    vectors = provider.generate_embeddings(docs)

    assert len(vectors) == 2
    assert all(len(vec) == 128 for vec in vectors)
|
||||
|
||||
|
||||
def test_local_provider_deterministic():
    """Repeated calls with the same input must yield identical vectors."""
    provider = LocalEmbeddingProvider(dimension=64)
    text = "same text"

    first = provider.generate_embeddings([text])[0]
    second = provider.generate_embeddings([text])[0]

    assert first == second
|
||||
|
||||
|
||||
def test_local_provider_cost():
    """Local inference is priced at zero regardless of token count."""
    assert LocalEmbeddingProvider().estimate_cost(1000) == 0.0  # Local is free
|
||||
|
||||
|
||||
def test_cache_memory():
    """An embedding stored in the memory cache round-trips unchanged."""
    cache = EmbeddingCache()
    vector = [0.1, 0.2, 0.3]

    cache.set("test text", "test-model", vector)

    assert cache.get("test text", "test-model") == vector
|
||||
|
||||
|
||||
def test_cache_miss():
    """Looking up a never-stored key yields None."""
    assert EmbeddingCache().get("nonexistent", "model") is None
|
||||
|
||||
|
||||
def test_cache_disk():
    """Entries written by one cache instance are readable by a fresh one."""
    with tempfile.TemporaryDirectory() as tmpdir:
        vector = [0.1, 0.2, 0.3]

        writer = EmbeddingCache(cache_dir=Path(tmpdir))
        writer.set("test text", "test-model", vector)

        # A new instance starts with an empty memory cache, forcing a disk read.
        reader = EmbeddingCache(cache_dir=Path(tmpdir))
        assert reader.get("test text", "test-model") == vector
|
||||
|
||||
|
||||
def test_cost_tracker():
    """Tracker aggregates tokens, requests, and hit/miss counts correctly."""
    tracker = CostTracker()
    tracker.add_request(token_count=1000, cost=0.01, from_cache=False)
    tracker.add_request(token_count=500, cost=0.005, from_cache=True)

    stats = tracker.get_stats()

    assert stats['total_requests'] == 2
    assert stats['total_tokens'] == 1500
    assert (stats['cache_hits'], stats['cache_misses']) == (1, 1)
    assert '50.0%' in stats['cache_rate']
|
||||
|
||||
|
||||
def test_pipeline_initialization():
    """Constructing a pipeline wires up the provider, cache, and config."""
    cfg = EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=128,
        batch_size=10,
    )

    pipeline = EmbeddingPipeline(cfg)

    assert pipeline.config == cfg
    assert pipeline.provider is not None
    assert pipeline.cache is not None
|
||||
|
||||
|
||||
def test_pipeline_generate_batch():
    """A cold run generates every embedding at the configured dimension."""
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=64,
        batch_size=2,
    ))

    result = pipeline.generate_batch(["doc 1", "doc 2", "doc 3"], show_progress=False)

    assert len(result.embeddings) == 3
    assert len(result.embeddings[0]) == 64
    assert result.generated_count == 3
    assert result.cached_count == 0
|
||||
|
||||
|
||||
def test_pipeline_caching():
    """A second identical run is served entirely from the cache."""
    with tempfile.TemporaryDirectory() as tmpdir:
        pipeline = EmbeddingPipeline(EmbeddingConfig(
            provider='local',
            model='test-model',
            dimension=32,
            batch_size=10,
            cache_dir=Path(tmpdir),
        ))

        docs = ["same doc", "same doc", "different doc"]

        # Cold run: everything generated.
        first = pipeline.generate_batch(docs, show_progress=False)
        assert (first.cached_count, first.generated_count) == (0, 3)

        # Warm run: everything cached.
        second = pipeline.generate_batch(docs, show_progress=False)
        assert (second.cached_count, second.generated_count) == (3, 0)
|
||||
|
||||
|
||||
def test_pipeline_batch_processing():
    """Ten inputs with batch_size=3 (four chunks) still yield ten vectors."""
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=16,
        batch_size=3,  # Small batch size
    ))

    result = pipeline.generate_batch([f"doc {i}" for i in range(10)], show_progress=False)

    assert len(result.embeddings) == 10
|
||||
|
||||
|
||||
def test_validate_dimensions_valid():
    """Validation passes when every vector has the expected width."""
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=128,
    ))

    assert pipeline.validate_dimensions([[0.1] * 128, [0.2] * 128])
|
||||
|
||||
|
||||
def test_validate_dimensions_invalid():
    """Validation fails when any vector has the wrong width."""
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=128,
    ))

    # First vector is deliberately half the expected width.
    assert not pipeline.validate_dimensions([[0.1] * 64, [0.2] * 128])
|
||||
|
||||
|
||||
def test_embedding_result_metadata():
    """The result's metadata reports provider, model, and dimension."""
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=256,
    ))

    meta = pipeline.generate_batch(["test"], show_progress=False).metadata

    assert {'provider', 'model', 'dimension'} <= set(meta)
    assert meta['dimension'] == 256
|
||||
|
||||
|
||||
def test_cost_stats():
    """The pipeline exposes cost-tracker statistics after a run."""
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=64,
    ))
    pipeline.generate_batch(["doc 1", "doc 2"], show_progress=False)

    stats = pipeline.get_cost_stats()

    for key in ('total_requests', 'cache_hits', 'estimated_cost'):
        assert key in stats
|
||||
|
||||
|
||||
def test_empty_batch():
    """An empty input list yields an empty result without errors."""
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=32,
    ))

    result = pipeline.generate_batch([], show_progress=False)

    assert len(result.embeddings) == 0
    assert result.generated_count == 0
|
||||
|
||||
|
||||
def test_single_document():
    """A one-item batch produces exactly one vector of the right width."""
    pipeline = EmbeddingPipeline(EmbeddingConfig(
        provider='local',
        model='test-model',
        dimension=128,
    ))

    result = pipeline.generate_batch(["single doc"], show_progress=False)

    assert len(result.embeddings) == 1
    assert len(result.embeddings[0]) == 128
|
||||
|
||||
|
||||
def test_different_dimensions():
    """The configured dimension is honored across a range of sizes."""
    for width in (64, 128, 256, 512):
        pipeline = EmbeddingPipeline(EmbeddingConfig(
            provider='local',
            model='test-model',
            dimension=width,
        ))

        result = pipeline.generate_batch(["test"], show_progress=False)

        assert len(result.embeddings[0]) == width
|
||||
|
||||
|
||||
# Allow running this test module directly; pytest also collects it normally.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user