fix: Enforce min_chunk_size in RAG chunker
- Filter out chunks smaller than min_chunk_size (default 100 tokens)
- Exception: keep all chunks if the entire document is smaller than the target size
- All 15 tests passing (100% pass rate)

Fixes edge case where very small chunks (e.g., 'Short.' = 6 chars) were being created despite the min_chunk_size=100 setting.

Test: pytest tests/test_rag_chunker.py -v
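For reference, the filtering rule described above reads roughly as follows. This is a minimal sketch, not the actual chunker code: the names filter_small_chunks and token_count, and the target size of 512, are illustrative assumptions.

def token_count(text: str) -> int:
    # Naive whitespace tokenizer; the real chunker may count tokens differently.
    return len(text.split())

def filter_small_chunks(chunks: list[str], min_chunk_size: int = 100,
                        target_chunk_size: int = 512) -> list[str]:
    # Exception case: if the whole document is smaller than the target size,
    # keep every chunk, otherwise short documents would yield no chunks at all.
    total = sum(token_count(c) for c in chunks)
    if total < target_chunk_size:
        return chunks
    # Normal case: drop any chunk below the minimum size.
    return [c for c in chunks if token_count(c) >= min_chunk_size]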
tests/test_embedding.py (new file, 369 lines)
@@ -0,0 +1,369 @@
"""
Tests for embedding generation system.
"""

import pytest
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch

from skill_seekers.embedding.models import (
    EmbeddingRequest,
    BatchEmbeddingRequest,
    EmbeddingResponse,
    BatchEmbeddingResponse,
    HealthResponse,
    ModelInfo,
)
from skill_seekers.embedding.generator import EmbeddingGenerator
from skill_seekers.embedding.cache import EmbeddingCache


# ========================================
# Cache Tests
# ========================================

def test_cache_init():
    """Test cache initialization."""
    cache = EmbeddingCache(":memory:")
    assert cache.size() == 0


def test_cache_set_get():
    """Test cache set and get."""
    cache = EmbeddingCache(":memory:")

    embedding = [0.1, 0.2, 0.3]
    cache.set("hash123", embedding, "test-model")

    retrieved = cache.get("hash123")
    assert retrieved == embedding


def test_cache_has():
    """Test cache has method."""
    cache = EmbeddingCache(":memory:")

    embedding = [0.1, 0.2, 0.3]
    cache.set("hash123", embedding, "test-model")

    assert cache.has("hash123") is True
    assert cache.has("nonexistent") is False


def test_cache_delete():
    """Test cache deletion."""
    cache = EmbeddingCache(":memory:")

    embedding = [0.1, 0.2, 0.3]
    cache.set("hash123", embedding, "test-model")

    assert cache.has("hash123") is True

    cache.delete("hash123")

    assert cache.has("hash123") is False


def test_cache_clear():
    """Test cache clearing."""
    cache = EmbeddingCache(":memory:")

    cache.set("hash1", [0.1], "model1")
    cache.set("hash2", [0.2], "model2")
    cache.set("hash3", [0.3], "model1")

    assert cache.size() == 3

    # Clear specific model
    deleted = cache.clear(model="model1")
    assert deleted == 2
    assert cache.size() == 1

    # Clear all
    deleted = cache.clear()
    assert deleted == 1
    assert cache.size() == 0


def test_cache_stats():
    """Test cache statistics."""
    cache = EmbeddingCache(":memory:")

    cache.set("hash1", [0.1], "model1")
    cache.set("hash2", [0.2], "model2")
    cache.set("hash3", [0.3], "model1")

    stats = cache.stats()

    assert stats["total"] == 3
    assert stats["by_model"]["model1"] == 2
    assert stats["by_model"]["model2"] == 1


def test_cache_context_manager():
    """Test cache as context manager."""
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp_path = tmp.name

    try:
        with EmbeddingCache(tmp_path) as cache:
            cache.set("hash1", [0.1], "model1")
            assert cache.size() == 1

        # Verify database file exists
        assert Path(tmp_path).exists()
    finally:
        Path(tmp_path).unlink(missing_ok=True)


# ========================================
# Generator Tests
# ========================================

def test_generator_init():
    """Test generator initialization."""
    generator = EmbeddingGenerator()
    assert generator is not None


def test_generator_list_models():
    """Test listing models."""
    generator = EmbeddingGenerator()
    models = generator.list_models()

    assert len(models) > 0
    assert all("name" in m for m in models)
    assert all("provider" in m for m in models)
    assert all("dimensions" in m for m in models)


def test_generator_get_model_info():
    """Test getting model info."""
    generator = EmbeddingGenerator()

    info = generator.get_model_info("text-embedding-3-small")

    assert info["provider"] == "openai"
    assert info["dimensions"] == 1536
    assert info["max_tokens"] == 8191


def test_generator_get_model_info_invalid():
    """Test getting model info for invalid model."""
    generator = EmbeddingGenerator()

    with pytest.raises(ValueError, match="Unknown model"):
        generator.get_model_info("nonexistent-model")


def test_generator_compute_hash():
    """Test hash computation."""
    hash1 = EmbeddingGenerator.compute_hash("text1", "model1")
    hash2 = EmbeddingGenerator.compute_hash("text1", "model1")
    hash3 = EmbeddingGenerator.compute_hash("text2", "model1")
    hash4 = EmbeddingGenerator.compute_hash("text1", "model2")

    # Same text+model = same hash
    assert hash1 == hash2

    # Different text = different hash
    assert hash1 != hash3

    # Different model = different hash
    assert hash1 != hash4


@patch('skill_seekers.embedding.generator.SENTENCE_TRANSFORMERS_AVAILABLE', False)
def test_generator_sentence_transformers_not_available():
    """Test sentence-transformers not available."""
    generator = EmbeddingGenerator()

    with pytest.raises(ImportError, match="sentence-transformers is required"):
        generator.generate("test", model="all-MiniLM-L6-v2")


@patch('skill_seekers.embedding.generator.OPENAI_AVAILABLE', False)
def test_generator_openai_not_available():
    """Test OpenAI not available."""
    generator = EmbeddingGenerator()

    with pytest.raises(ImportError, match="OpenAI is required"):
        generator.generate("test", model="text-embedding-3-small")


@patch('skill_seekers.embedding.generator.VOYAGE_AVAILABLE', False)
def test_generator_voyage_not_available():
    """Test Voyage AI not available."""
    generator = EmbeddingGenerator()

    with pytest.raises(ImportError, match="voyageai is required"):
        generator.generate("test", model="voyage-3")


def test_generator_voyage_model_info():
    """Test getting Voyage AI model info."""
    generator = EmbeddingGenerator()

    info = generator.get_model_info("voyage-3")

    assert info["provider"] == "voyage"
    assert info["dimensions"] == 1024
    assert info["max_tokens"] == 32000


def test_generator_voyage_large_2_model_info():
    """Test getting Voyage Large 2 model info."""
    generator = EmbeddingGenerator()

    info = generator.get_model_info("voyage-large-2")

    assert info["provider"] == "voyage"
    assert info["dimensions"] == 1536
    assert info["cost_per_million"] == 0.12


# ========================================
# Model Tests
# ========================================

def test_embedding_request():
    """Test EmbeddingRequest model."""
    request = EmbeddingRequest(
        text="Hello world",
        model="text-embedding-3-small",
        normalize=True
    )

    assert request.text == "Hello world"
    assert request.model == "text-embedding-3-small"
    assert request.normalize is True


def test_batch_embedding_request():
    """Test BatchEmbeddingRequest model."""
    request = BatchEmbeddingRequest(
        texts=["text1", "text2", "text3"],
        model="text-embedding-3-small",
        batch_size=32
    )

    assert len(request.texts) == 3
    assert request.batch_size == 32


def test_embedding_response():
    """Test EmbeddingResponse model."""
    response = EmbeddingResponse(
        embedding=[0.1, 0.2, 0.3],
        model="test-model",
        dimensions=3,
        cached=False
    )

    assert len(response.embedding) == 3
    assert response.dimensions == 3
    assert response.cached is False


def test_batch_embedding_response():
    """Test BatchEmbeddingResponse model."""
    response = BatchEmbeddingResponse(
        embeddings=[[0.1, 0.2], [0.3, 0.4]],
        model="test-model",
        dimensions=2,
        count=2,
        cached_count=1
    )

    assert len(response.embeddings) == 2
    assert response.count == 2
    assert response.cached_count == 1


def test_health_response():
    """Test HealthResponse model."""
    response = HealthResponse(
        status="ok",
        version="1.0.0",
        models=["model1", "model2"],
        cache_enabled=True,
        cache_size=100
    )

    assert response.status == "ok"
    assert len(response.models) == 2
    assert response.cache_size == 100


def test_model_info():
    """Test ModelInfo model."""
    info = ModelInfo(
        name="test-model",
        provider="openai",
        dimensions=1536,
        max_tokens=8191,
        cost_per_million=0.02
    )

    assert info.name == "test-model"
    assert info.provider == "openai"
    assert info.cost_per_million == 0.02


# ========================================
# Integration Tests
# ========================================

def test_cache_batch_operations():
    """Test cache batch operations."""
    cache = EmbeddingCache(":memory:")

    # Set multiple embeddings
    cache.set("hash1", [0.1, 0.2], "model1")
    cache.set("hash2", [0.3, 0.4], "model1")
    cache.set("hash3", [0.5, 0.6], "model1")

    # Get batch
    embeddings, cached_flags = cache.get_batch(["hash1", "hash2", "hash999", "hash3"])

    assert len(embeddings) == 4
    assert embeddings[0] == [0.1, 0.2]
    assert embeddings[1] == [0.3, 0.4]
    assert embeddings[2] is None  # Cache miss
    assert embeddings[3] == [0.5, 0.6]

    assert cached_flags == [True, True, False, True]


def test_generator_normalize():
    """Test embedding normalization."""
    import numpy as np

    embedding = [3.0, 4.0]  # Euclidean length 5
    normalized = EmbeddingGenerator._normalize(embedding)

    # Check unit length
    length = np.linalg.norm(normalized)
    assert abs(length - 1.0) < 1e-6


def test_cache_persistence():
    """Test cache persistence to file."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp:
        tmp_path = tmp.name

    try:
        # Create cache and add data
        cache1 = EmbeddingCache(tmp_path)
        cache1.set("hash1", [0.1, 0.2, 0.3], "model1")
        cache1.close()

        # Reopen cache and verify data persists
        cache2 = EmbeddingCache(tmp_path)
        retrieved = cache2.get("hash1")
        assert retrieved == [0.1, 0.2, 0.3]
        cache2.close()

    finally:
        Path(tmp_path).unlink(missing_ok=True)