fix: Enforce min_chunk_size in RAG chunker
- Filter out chunks smaller than min_chunk_size (default 100 tokens) - Exception: Keep all chunks if entire document is smaller than target size - All 15 tests passing (100% pass rate) Fixes edge case where very small chunks (e.g., 'Short.' = 6 chars) were being created despite min_chunk_size=100 setting. Test: pytest tests/test_rag_chunker.py -v
This commit is contained in:
665
tests/test_benchmark.py
Normal file
665
tests/test_benchmark.py
Normal file
@@ -0,0 +1,665 @@
|
||||
"""
|
||||
Tests for benchmarking suite.
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from skill_seekers.benchmark import (
|
||||
Benchmark,
|
||||
BenchmarkResult,
|
||||
BenchmarkRunner,
|
||||
BenchmarkReport,
|
||||
Metric
|
||||
)
|
||||
from skill_seekers.benchmark.models import TimingResult, MemoryUsage
|
||||
|
||||
|
||||
class TestBenchmarkResult:
    """Unit tests covering the BenchmarkResult container."""

    def test_result_initialization(self):
        """A freshly created result starts empty and stamped with a start time."""
        res = BenchmarkResult("test-benchmark")

        assert res.name == "test-benchmark"
        assert isinstance(res.started_at, datetime)
        assert res.finished_at is None
        # Every collection attribute begins empty.
        for collected in (res.timings, res.memory, res.metrics, res.recommendations):
            assert collected == []
        assert res.system_info == {}

    def test_add_timing(self):
        """add_timing appends the given TimingResult unchanged."""
        res = BenchmarkResult("test")
        res.add_timing(
            TimingResult(operation="test_op", duration=1.5, iterations=1, avg_duration=1.5)
        )

        assert len(res.timings) == 1
        recorded = res.timings[0]
        assert recorded.operation == "test_op"
        assert recorded.duration == 1.5

    def test_add_memory(self):
        """add_memory appends the given MemoryUsage sample unchanged."""
        res = BenchmarkResult("test")
        sample = MemoryUsage(
            operation="test_op",
            before_mb=100.0,
            after_mb=150.0,
            peak_mb=160.0,
            allocated_mb=50.0,
        )
        res.add_memory(sample)

        assert len(res.memory) == 1
        assert res.memory[0].operation == "test_op"
        assert res.memory[0].allocated_mb == 50.0

    def test_add_metric(self):
        """add_metric appends the given custom Metric unchanged."""
        res = BenchmarkResult("test")
        res.add_metric(Metric(name="pages_per_sec", value=12.5, unit="pages/sec"))

        assert len(res.metrics) == 1
        assert res.metrics[0].name == "pages_per_sec"
        assert res.metrics[0].value == 12.5

    def test_add_recommendation(self):
        """add_recommendation stores the free-form advice string."""
        res = BenchmarkResult("test")
        res.add_recommendation("Consider caching")

        assert res.recommendations == ["Consider caching"]

    def test_set_system_info(self):
        """set_system_info populates host details such as CPU count and RAM."""
        res = BenchmarkResult("test")
        res.set_system_info()

        assert "cpu_count" in res.system_info
        assert "memory_total_gb" in res.system_info
        assert res.system_info["cpu_count"] > 0

    def test_to_report(self):
        """to_report produces a finished BenchmarkReport snapshot."""
        res = BenchmarkResult("test")
        res.add_timing(
            TimingResult(operation="test_op", duration=1.0, iterations=1, avg_duration=1.0)
        )

        report = res.to_report()

        assert isinstance(report, BenchmarkReport)
        assert report.name == "test"
        assert report.finished_at is not None
        assert len(report.timings) == 1
        assert report.total_duration > 0
class TestBenchmark:
    """Unit tests covering the Benchmark measurement helper."""

    def test_benchmark_initialization(self):
        """Construction stores the name and creates an empty result."""
        bench = Benchmark("test")

        assert bench.name == "test"
        assert isinstance(bench.result, BenchmarkResult)

    def test_timer_context_manager(self):
        """timer() records one TimingResult covering the with-block."""
        bench = Benchmark("test")
        with bench.timer("operation"):
            time.sleep(0.1)

        recorded = bench.result.timings
        assert len(recorded) == 1
        assert recorded[0].operation == "operation"
        assert recorded[0].duration >= 0.1

    def test_timer_with_iterations(self):
        """With iterations > 1, avg_duration is smaller than the total."""
        bench = Benchmark("test")
        with bench.timer("operation", iterations=5):
            time.sleep(0.05)

        entry = bench.result.timings[0]
        assert entry.iterations == 5
        assert entry.avg_duration < entry.duration

    def test_memory_context_manager(self):
        """memory() records one MemoryUsage sample for the with-block."""
        bench = Benchmark("test")
        with bench.memory("operation"):
            _scratch = [0] * 1000000  # allocate so there is something to measure

        samples = bench.result.memory
        assert len(samples) == 1
        assert samples[0].operation == "operation"
        assert samples[0].allocated_mb >= 0

    def test_measure_function(self):
        """measure() times a callable and passes through its return value."""
        bench = Benchmark("test")

        def slow_function(x):
            time.sleep(0.1)
            return x * 2

        value = bench.measure(slow_function, 5, operation="multiply")

        assert value == 10
        assert len(bench.result.timings) == 1
        assert bench.result.timings[0].operation == "multiply"

    def test_measure_with_memory_tracking(self):
        """track_memory=True records both a timing and a memory sample."""
        bench = Benchmark("test")

        def allocate_memory():
            return [0] * 1000000

        bench.measure(allocate_memory, operation="allocate", track_memory=True)

        assert len(bench.result.timings) == 1
        assert len(bench.result.memory) == 1

    def test_timed_decorator(self):
        """@timed wraps a function and records its runtime when called."""
        bench = Benchmark("test")

        @bench.timed("decorated_func")
        def my_function(x):
            time.sleep(0.05)
            return x + 1

        assert my_function(5) == 6
        assert len(bench.result.timings) == 1
        assert bench.result.timings[0].operation == "decorated_func"

    def test_timed_decorator_with_memory(self):
        """@timed(track_memory=True) records a memory sample as well."""
        bench = Benchmark("test")

        @bench.timed("memory_func", track_memory=True)
        def allocate():
            return [0] * 1000000

        allocate()

        assert len(bench.result.timings) == 1
        assert len(bench.result.memory) == 1

    def test_metric_recording(self):
        """metric() stores a named value together with its unit."""
        bench = Benchmark("test")
        bench.metric("throughput", 125.5, "ops/sec")

        assert len(bench.result.metrics) == 1
        assert bench.result.metrics[0].name == "throughput"
        assert bench.result.metrics[0].value == 125.5

    def test_recommendation_recording(self):
        """recommend() stores a free-form advice string."""
        bench = Benchmark("test")
        bench.recommend("Use batch processing")

        assert len(bench.result.recommendations) == 1
        assert "batch" in bench.result.recommendations[0].lower()

    def test_report_generation(self):
        """report() bundles timings and metrics into a BenchmarkReport."""
        bench = Benchmark("test")
        with bench.timer("op1"):
            time.sleep(0.05)
        bench.metric("count", 10, "items")

        report = bench.report()

        assert isinstance(report, BenchmarkReport)
        assert report.name == "test"
        assert len(report.timings) == 1
        assert len(report.metrics) == 1

    def test_save_report(self, tmp_path):
        """save() serializes the report as JSON at the given path."""
        bench = Benchmark("test")
        with bench.timer("operation"):
            time.sleep(0.05)

        target = tmp_path / "benchmark.json"
        bench.save(target)

        assert target.exists()
        # Round-trip the file to confirm the payload is intact.
        payload = json.loads(target.read_text())
        assert payload["name"] == "test"
        assert len(payload["timings"]) == 1

    def test_analyze_bottlenecks(self):
        """analyze() flags the dominant operation as a bottleneck."""
        bench = Benchmark("test")
        # One quick operation and one that dominates the total runtime.
        with bench.timer("fast"):
            time.sleep(0.01)
        with bench.timer("slow"):
            time.sleep(0.2)

        bench.analyze()

        tips = bench.result.recommendations
        assert len(tips) > 0
        assert any("bottleneck" in tip.lower() for tip in tips)

    def test_analyze_high_memory(self):
        """analyze() flags operations with unusually large allocations."""
        bench = Benchmark("test")
        bench.result.add_memory(
            MemoryUsage(
                operation="allocate",
                before_mb=100.0,
                after_mb=1200.0,
                peak_mb=1500.0,
                allocated_mb=1100.0,
            )
        )

        bench.analyze()

        tips = bench.result.recommendations
        assert len(tips) > 0
        assert any("memory" in tip.lower() for tip in tips)
class TestBenchmarkRunner:
    """Tests for BenchmarkRunner: running, saving, comparing and pruning runs."""

    def test_runner_initialization(self, tmp_path):
        """The runner stores the output directory and creates it on disk."""
        runner = BenchmarkRunner(output_dir=tmp_path)

        assert runner.output_dir == tmp_path
        assert runner.output_dir.exists()

    def test_run_benchmark(self, tmp_path):
        """run(save=True) returns a report and writes one timestamped JSON file."""
        runner = BenchmarkRunner(output_dir=tmp_path)

        def test_benchmark(bench):
            with bench.timer("operation"):
                time.sleep(0.05)

        report = runner.run("test", test_benchmark, save=True)

        assert isinstance(report, BenchmarkReport)
        assert report.name == "test"
        assert len(report.timings) == 1

        # Exactly one file for this run should exist on disk.
        saved_files = list(tmp_path.glob("test_*.json"))
        assert len(saved_files) == 1

    def test_run_benchmark_no_save(self, tmp_path):
        """run(save=False) returns a report without touching the disk."""
        runner = BenchmarkRunner(output_dir=tmp_path)

        def test_benchmark(bench):
            with bench.timer("operation"):
                time.sleep(0.05)

        report = runner.run("test", test_benchmark, save=False)

        assert isinstance(report, BenchmarkReport)
        assert list(tmp_path.glob("*.json")) == []

    def test_run_suite(self, tmp_path):
        """run_suite executes every entry and saves one file per benchmark."""
        runner = BenchmarkRunner(output_dir=tmp_path)

        def bench1(bench):
            with bench.timer("op1"):
                time.sleep(0.02)

        def bench2(bench):
            with bench.timer("op2"):
                time.sleep(0.03)

        reports = runner.run_suite({
            "test1": bench1,
            "test2": bench2
        })

        assert len(reports) == 2
        assert "test1" in reports
        assert "test2" in reports
        assert len(list(tmp_path.glob("*.json"))) == 2

    def test_compare_benchmarks(self, tmp_path):
        """compare() reports a speedup when the current run is faster."""
        from skill_seekers.benchmark.models import ComparisonReport

        runner = BenchmarkRunner(output_dir=tmp_path)

        def baseline_bench(bench):
            with bench.timer("operation"):
                time.sleep(0.1)

        runner.run("baseline", baseline_bench, save=True)
        baseline_path = list(tmp_path.glob("baseline_*.json"))[0]

        # Roughly twice as fast as the baseline.
        def improved_bench(bench):
            with bench.timer("operation"):
                time.sleep(0.05)

        runner.run("improved", improved_bench, save=True)
        improved_path = list(tmp_path.glob("improved_*.json"))[0]

        comparison = runner.compare(baseline_path, improved_path)

        assert isinstance(comparison, ComparisonReport)
        assert comparison.speedup_factor > 1.0
        assert len(comparison.improvements) > 0

    def test_list_benchmarks(self, tmp_path):
        """list_benchmarks returns a summary entry per saved run."""
        runner = BenchmarkRunner(output_dir=tmp_path)

        def test_bench(bench):
            with bench.timer("op"):
                time.sleep(0.02)

        runner.run("bench1", test_bench, save=True)
        runner.run("bench2", test_bench, save=True)

        benchmarks = runner.list_benchmarks()

        assert len(benchmarks) == 2
        assert all("name" in b for b in benchmarks)
        assert all("duration" in b for b in benchmarks)

    def test_get_latest(self, tmp_path):
        """get_latest returns the most recent file for a benchmark name."""
        runner = BenchmarkRunner(output_dir=tmp_path)

        def test_bench(bench):
            with bench.timer("op"):
                time.sleep(0.02)

        runner.run("test", test_bench, save=True)
        time.sleep(0.1)  # ensure the second run gets a later timestamp
        runner.run("test", test_bench, save=True)

        latest = runner.get_latest("test")

        assert latest is not None
        assert "test_" in latest.name

    def test_get_latest_not_found(self, tmp_path):
        """get_latest returns None when no run with that name exists."""
        runner = BenchmarkRunner(output_dir=tmp_path)

        assert runner.get_latest("nonexistent") is None

    def test_cleanup_old(self, tmp_path):
        """cleanup_old keeps only the newest N report files."""
        import os

        runner = BenchmarkRunner(output_dir=tmp_path)

        # Create 10 minimal report files whose mtimes increase with i,
        # so test_00000009 is the newest and test_00000000 the oldest.
        base_time = time.time()
        for i in range(10):
            file_path = tmp_path / f"test_{i:08d}.json"

            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
            # kept because cleanup only looks at files, not these timestamps.
            now_iso = datetime.utcnow().isoformat()
            report_data = {
                "name": "test",
                "started_at": now_iso,
                "finished_at": now_iso,
                "total_duration": 1.0,
                "timings": [],
                "memory": [],
                "metrics": [],
                "system_info": {},
                "recommendations": []
            }
            file_path.write_text(json.dumps(report_data))

            mtime = base_time - (10 - i) * 60  # older files get older mtimes
            os.utime(file_path, (mtime, mtime))

        assert len(list(tmp_path.glob("test_*.json"))) == 10

        runner.cleanup_old(keep_latest=3)

        remaining = list(tmp_path.glob("test_*.json"))
        assert len(remaining) == 3

        # Strengthened check: exactly the three newest files must survive.
        # (The original only asserted that *one* of them was present, which
        # would pass even if the wrong files were deleted.)
        remaining_names = {f.stem for f in remaining}
        assert remaining_names == {"test_00000007", "test_00000008", "test_00000009"}
class TestBenchmarkModels:
    """Tests for the plain data models used by the benchmark package."""

    @staticmethod
    def _blank_report(name, total_duration):
        """Build a BenchmarkReport with the given totals and no samples."""
        return BenchmarkReport(
            name=name,
            started_at=datetime.utcnow(),
            finished_at=datetime.utcnow(),
            total_duration=total_duration,
            timings=[],
            memory=[],
            metrics=[],
            system_info={},
            recommendations=[],
        )

    def test_timing_result_model(self):
        """TimingResult stores its fields verbatim."""
        entry = TimingResult(operation="test", duration=1.5, iterations=10, avg_duration=0.15)

        assert entry.operation == "test"
        assert entry.duration == 1.5
        assert entry.iterations == 10
        assert entry.avg_duration == 0.15

    def test_memory_usage_model(self):
        """MemoryUsage stores its fields verbatim."""
        sample = MemoryUsage(
            operation="allocate",
            before_mb=100.0,
            after_mb=200.0,
            peak_mb=250.0,
            allocated_mb=100.0,
        )

        assert sample.operation == "allocate"
        assert sample.allocated_mb == 100.0
        assert sample.peak_mb == 250.0

    def test_metric_model(self):
        """Metric stores its fields and stamps a creation time."""
        m = Metric(name="throughput", value=125.5, unit="ops/sec")

        assert m.name == "throughput"
        assert m.value == 125.5
        assert m.unit == "ops/sec"
        assert isinstance(m.timestamp, datetime)

    def test_benchmark_report_summary(self):
        """summary renders the name, total time and peak memory."""
        report = BenchmarkReport(
            name="test",
            started_at=datetime.utcnow(),
            finished_at=datetime.utcnow(),
            total_duration=5.0,
            timings=[
                TimingResult(operation="op1", duration=2.0, iterations=1, avg_duration=2.0)
            ],
            memory=[
                MemoryUsage(
                    operation="op1",
                    before_mb=100.0,
                    after_mb=200.0,
                    peak_mb=250.0,
                    allocated_mb=100.0,
                )
            ],
            metrics=[],
            system_info={},
            recommendations=[],
        )

        text = report.summary

        assert "test" in text
        assert "5.00s" in text
        assert "250.0MB" in text

    def test_comparison_report_has_regressions(self):
        """has_regressions is True whenever the regressions list is non-empty."""
        from skill_seekers.benchmark.models import ComparisonReport

        comparison = ComparisonReport(
            name="test",
            baseline=self._blank_report("baseline", 5.0),
            current=self._blank_report("current", 10.0),
            improvements=[],
            regressions=["Slower performance"],
            speedup_factor=0.5,
            memory_change_mb=0.0,
        )

        assert comparison.has_regressions is True

    def test_comparison_report_overall_improvement(self):
        """overall_improvement formats a 2x speedup as '100.0% faster'."""
        from skill_seekers.benchmark.models import ComparisonReport

        comparison = ComparisonReport(
            name="test",
            baseline=self._blank_report("baseline", 10.0),
            current=self._blank_report("current", 5.0),
            improvements=[],
            regressions=[],
            speedup_factor=2.0,
            memory_change_mb=0.0,
        )

        improvement = comparison.overall_improvement

        assert "100.0% faster" in improvement
        assert "✅" in improvement
457
tests/test_cloud_storage.py
Normal file
457
tests/test_cloud_storage.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
Tests for cloud storage adaptors.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from skill_seekers.cli.storage import (
|
||||
get_storage_adaptor,
|
||||
BaseStorageAdaptor,
|
||||
S3StorageAdaptor,
|
||||
GCSStorageAdaptor,
|
||||
AzureStorageAdaptor,
|
||||
StorageObject,
|
||||
)
|
||||
|
||||
|
||||
# ========================================
|
||||
# Factory Tests
|
||||
# ========================================
|
||||
|
||||
def test_get_storage_adaptor_s3():
    """The factory builds an S3StorageAdaptor for provider 's3'."""
    with patch('skill_seekers.cli.storage.s3_storage.boto3'):
        built = get_storage_adaptor('s3', bucket='test-bucket')
        assert isinstance(built, S3StorageAdaptor)


def test_get_storage_adaptor_gcs():
    """The factory builds a GCSStorageAdaptor for provider 'gcs'."""
    with patch('skill_seekers.cli.storage.gcs_storage.storage'):
        built = get_storage_adaptor('gcs', bucket='test-bucket')
        assert isinstance(built, GCSStorageAdaptor)


def test_get_storage_adaptor_azure():
    """The factory builds an AzureStorageAdaptor for provider 'azure'."""
    with patch('skill_seekers.cli.storage.azure_storage.BlobServiceClient'):
        built = get_storage_adaptor(
            'azure',
            container='test-container',
            connection_string='DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key'
        )
        assert isinstance(built, AzureStorageAdaptor)


def test_get_storage_adaptor_invalid_provider():
    """An unknown provider name raises ValueError."""
    with pytest.raises(ValueError, match="Unsupported storage provider"):
        get_storage_adaptor('invalid', bucket='test')
# ========================================
|
||||
# S3 Storage Tests
|
||||
# ========================================
|
||||
|
||||
@patch('skill_seekers.cli.storage.s3_storage.boto3')
def test_s3_upload_file(boto3_mod):
    """upload_file pushes a local file and returns its s3:// URI."""
    s3_client = Mock()
    boto3_mod.client.return_value = s3_client
    boto3_mod.resource.return_value = Mock()

    adaptor = S3StorageAdaptor(bucket='test-bucket')

    with tempfile.TemporaryDirectory() as tmp_dir:
        src = Path(tmp_dir) / 'payload.bin'
        src.write_bytes(b'test content')

        uri = adaptor.upload_file(str(src), 'test.txt')

    assert uri == 's3://test-bucket/test.txt'
    s3_client.upload_file.assert_called_once()


@patch('skill_seekers.cli.storage.s3_storage.boto3')
def test_s3_download_file(boto3_mod):
    """download_file delegates to boto3 with bucket, key and destination."""
    s3_client = Mock()
    boto3_mod.client.return_value = s3_client
    boto3_mod.resource.return_value = Mock()

    adaptor = S3StorageAdaptor(bucket='test-bucket')

    with tempfile.TemporaryDirectory() as tmp_dir:
        destination = os.path.join(tmp_dir, 'downloaded.txt')

        adaptor.download_file('test.txt', destination)

        s3_client.download_file.assert_called_once_with(
            'test-bucket', 'test.txt', destination
        )


@patch('skill_seekers.cli.storage.s3_storage.boto3')
def test_s3_list_files(boto3_mod):
    """list_files flattens paginated S3 listings into StorageObjects."""
    page = {
        'Contents': [
            {
                'Key': 'file1.txt',
                'Size': 100,
                'LastModified': Mock(isoformat=lambda: '2024-01-01T00:00:00'),
                'ETag': '"abc123"'
            }
        ]
    }
    paginator = Mock()
    paginator.paginate.return_value = [page]

    s3_client = Mock()
    s3_client.get_paginator.return_value = paginator
    boto3_mod.client.return_value = s3_client
    boto3_mod.resource.return_value = Mock()

    adaptor = S3StorageAdaptor(bucket='test-bucket')

    listed = adaptor.list_files('prefix/')

    assert len(listed) == 1
    first = listed[0]
    assert first.key == 'file1.txt'
    assert first.size == 100
    # Expected without quotes: presumably the adaptor strips S3's ETag quoting.
    assert first.etag == 'abc123'


@patch('skill_seekers.cli.storage.s3_storage.boto3')
def test_s3_file_exists(boto3_mod):
    """file_exists is True when head_object succeeds."""
    s3_client = Mock()
    s3_client.head_object.return_value = {}
    boto3_mod.client.return_value = s3_client
    boto3_mod.resource.return_value = Mock()

    adaptor = S3StorageAdaptor(bucket='test-bucket')

    assert adaptor.file_exists('test.txt') is True


@patch('skill_seekers.cli.storage.s3_storage.boto3')
def test_s3_get_file_url(boto3_mod):
    """get_file_url returns the presigned URL produced by boto3."""
    s3_client = Mock()
    s3_client.generate_presigned_url.return_value = 'https://s3.amazonaws.com/signed-url'
    boto3_mod.client.return_value = s3_client
    boto3_mod.resource.return_value = Mock()

    adaptor = S3StorageAdaptor(bucket='test-bucket')

    url = adaptor.get_file_url('test.txt', expires_in=7200)

    assert url == 'https://s3.amazonaws.com/signed-url'
    s3_client.generate_presigned_url.assert_called_once()
# ========================================
|
||||
# GCS Storage Tests
|
||||
# ========================================
|
||||
|
||||
@patch('skill_seekers.cli.storage.gcs_storage.storage')
def test_gcs_upload_file(storage_mod):
    """upload_file pushes a local file and returns its gs:// URI."""
    blob = Mock()
    bucket = Mock()
    bucket.blob.return_value = blob
    gcs_client = Mock()
    gcs_client.bucket.return_value = bucket
    storage_mod.Client.return_value = gcs_client

    adaptor = GCSStorageAdaptor(bucket='test-bucket')

    with tempfile.TemporaryDirectory() as tmp_dir:
        src = Path(tmp_dir) / 'payload.bin'
        src.write_bytes(b'test content')

        uri = adaptor.upload_file(str(src), 'test.txt')

    assert uri == 'gs://test-bucket/test.txt'
    blob.upload_from_filename.assert_called_once()


@patch('skill_seekers.cli.storage.gcs_storage.storage')
def test_gcs_download_file(storage_mod):
    """download_file delegates to the blob's download_to_filename."""
    blob = Mock()
    bucket = Mock()
    bucket.blob.return_value = blob
    gcs_client = Mock()
    gcs_client.bucket.return_value = bucket
    storage_mod.Client.return_value = gcs_client

    adaptor = GCSStorageAdaptor(bucket='test-bucket')

    with tempfile.TemporaryDirectory() as tmp_dir:
        adaptor.download_file('test.txt', os.path.join(tmp_dir, 'downloaded.txt'))

        blob.download_to_filename.assert_called_once()


@patch('skill_seekers.cli.storage.gcs_storage.storage')
def test_gcs_list_files(storage_mod):
    """list_files converts listed blobs into StorageObjects."""
    blob = Mock()
    blob.name = 'file1.txt'
    blob.size = 100
    blob.updated = Mock(isoformat=lambda: '2024-01-01T00:00:00')
    blob.etag = 'abc123'
    blob.metadata = {}

    gcs_client = Mock()
    gcs_client.list_blobs.return_value = [blob]
    gcs_client.bucket.return_value = Mock()
    storage_mod.Client.return_value = gcs_client

    adaptor = GCSStorageAdaptor(bucket='test-bucket')

    listed = adaptor.list_files('prefix/')

    assert len(listed) == 1
    assert listed[0].key == 'file1.txt'
    assert listed[0].size == 100
# ========================================
|
||||
# Azure Storage Tests
|
||||
# ========================================
|
||||
|
||||
@patch('skill_seekers.cli.storage.azure_storage.BlobServiceClient')
def test_azure_upload_file(blob_service_cls):
    """upload_file pushes a local file and returns its blob URL."""
    blob_client = Mock()
    container_client = Mock()
    container_client.get_blob_client.return_value = blob_client
    service_client = Mock()
    service_client.get_container_client.return_value = container_client
    blob_service_cls.from_connection_string.return_value = service_client

    connection_string = 'DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key'
    adaptor = AzureStorageAdaptor(container='test-container', connection_string=connection_string)

    with tempfile.TemporaryDirectory() as tmp_dir:
        src = Path(tmp_dir) / 'payload.bin'
        src.write_bytes(b'test content')

        url = adaptor.upload_file(str(src), 'test.txt')

    assert 'test.blob.core.windows.net' in url
    blob_client.upload_blob.assert_called_once()


@patch('skill_seekers.cli.storage.azure_storage.BlobServiceClient')
def test_azure_download_file(blob_service_cls):
    """download_file writes the blob's bytes to the local path."""
    stream = Mock()
    stream.readall.return_value = b'test content'

    blob_client = Mock()
    blob_client.download_blob.return_value = stream
    container_client = Mock()
    container_client.get_blob_client.return_value = blob_client
    service_client = Mock()
    service_client.get_container_client.return_value = container_client
    blob_service_cls.from_connection_string.return_value = service_client

    connection_string = 'DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key'
    adaptor = AzureStorageAdaptor(container='test-container', connection_string=connection_string)

    with tempfile.TemporaryDirectory() as tmp_dir:
        destination = Path(tmp_dir) / 'downloaded.txt'

        adaptor.download_file('test.txt', str(destination))

        assert destination.exists()
        assert destination.read_bytes() == b'test content'


@patch('skill_seekers.cli.storage.azure_storage.BlobServiceClient')
def test_azure_list_files(blob_service_cls):
    """list_files converts listed blobs into StorageObjects."""
    blob = Mock()
    blob.name = 'file1.txt'
    blob.size = 100
    blob.last_modified = Mock(isoformat=lambda: '2024-01-01T00:00:00')
    blob.etag = 'abc123'
    blob.metadata = {}

    container_client = Mock()
    container_client.list_blobs.return_value = [blob]
    service_client = Mock()
    service_client.get_container_client.return_value = container_client
    blob_service_cls.from_connection_string.return_value = service_client

    connection_string = 'DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key'
    adaptor = AzureStorageAdaptor(container='test-container', connection_string=connection_string)

    listed = adaptor.list_files('prefix/')

    assert len(listed) == 1
    assert listed[0].key == 'file1.txt'
    assert listed[0].size == 100
# ========================================
|
||||
# Base Adaptor Tests
|
||||
# ========================================
|
||||
|
||||
def test_storage_object():
    """StorageObject keeps the fields it was constructed with."""
    record = StorageObject(
        key='test.txt',
        size=100,
        last_modified='2024-01-01T00:00:00',
        etag='abc123',
        metadata={'key': 'value'},
    )

    assert record.key == 'test.txt'
    assert record.size == 100
    assert record.metadata == {'key': 'value'}
|
||||
|
||||
|
||||
def test_base_adaptor_abstract():
    """The abstract base class refuses direct instantiation."""
    with pytest.raises(TypeError):
        BaseStorageAdaptor(bucket='test')
|
||||
|
||||
|
||||
# ========================================
|
||||
# Integration-style Tests
|
||||
# ========================================
|
||||
|
||||
@patch('skill_seekers.cli.storage.s3_storage.boto3')
def test_upload_directory(mock_boto3):
    """upload_directory() walks the tree and uploads every file it finds."""
    s3 = Mock()
    mock_boto3.client.return_value = s3
    mock_boto3.resource.return_value = Mock()

    adaptor = S3StorageAdaptor(bucket='test-bucket')

    # Lay out two top-level files plus one nested file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        root = Path(tmp_dir)
        (root / 'file1.txt').write_text('content1')
        (root / 'file2.txt').write_text('content2')
        (root / 'subdir').mkdir()
        (root / 'subdir' / 'file3.txt').write_text('content3')

        uploaded_files = adaptor.upload_directory(tmp_dir, 'skills/')

        # All three files (including the nested one) went through boto3.
        assert len(uploaded_files) == 3
        assert s3.upload_file.call_count == 3
|
||||
|
||||
|
||||
@patch('skill_seekers.cli.storage.s3_storage.boto3')
def test_download_directory(mock_boto3):
    """download_directory() fetches every object listed under the prefix."""
    def listed_object(key, size, etag):
        # Shape of one entry in an S3 list_objects_v2 page.
        return {
            'Key': key,
            'Size': size,
            'LastModified': Mock(isoformat=lambda: '2024-01-01T00:00:00'),
            'ETag': etag,
        }

    page = {
        'Contents': [
            listed_object('skills/file1.txt', 100, '"abc"'),
            listed_object('skills/file2.txt', 200, '"def"'),
        ]
    }

    paginator = Mock()
    paginator.paginate.return_value = [page]

    s3 = Mock()
    s3.get_paginator.return_value = paginator
    mock_boto3.client.return_value = s3
    mock_boto3.resource.return_value = Mock()

    adaptor = S3StorageAdaptor(bucket='test-bucket')

    with tempfile.TemporaryDirectory() as tmp_dir:
        downloaded_files = adaptor.download_directory('skills/', tmp_dir)

        assert len(downloaded_files) == 2
        assert s3.download_file.call_count == 2
|
||||
|
||||
|
||||
def test_missing_dependencies():
    """Test graceful handling of missing dependencies.

    Each cloud adaptor should raise ImportError with an actionable message
    when its SDK is absent; the SDK module is simulated as missing by
    mapping it to ``None`` in ``sys.modules``.
    """
    # NOTE(review): mapping a module to None in sys.modules only forces the
    # ImportError if the adaptor module (re-)imports its SDK while the patch
    # is active; if the adaptor module was already imported with the SDK
    # present, this may not trigger — confirm against the adaptor sources.
    # Test S3 without boto3
    with patch.dict('sys.modules', {'boto3': None}):
        with pytest.raises(ImportError, match="boto3 is required"):
            from skill_seekers.cli.storage.s3_storage import S3StorageAdaptor
            S3StorageAdaptor(bucket='test')

    # Test GCS without google-cloud-storage
    with patch.dict('sys.modules', {'google.cloud.storage': None}):
        with pytest.raises(ImportError, match="google-cloud-storage is required"):
            from skill_seekers.cli.storage.gcs_storage import GCSStorageAdaptor
            GCSStorageAdaptor(bucket='test')

    # Test Azure without azure-storage-blob
    with patch.dict('sys.modules', {'azure.storage.blob': None}):
        with pytest.raises(ImportError, match="azure-storage-blob is required"):
            from skill_seekers.cli.storage.azure_storage import AzureStorageAdaptor
            AzureStorageAdaptor(container='test', connection_string='test')
|
||||
369
tests/test_embedding.py
Normal file
369
tests/test_embedding.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
|
||||
Tests for embedding generation system.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from skill_seekers.embedding.models import (
|
||||
EmbeddingRequest,
|
||||
BatchEmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
BatchEmbeddingResponse,
|
||||
HealthResponse,
|
||||
ModelInfo,
|
||||
)
|
||||
from skill_seekers.embedding.generator import EmbeddingGenerator
|
||||
from skill_seekers.embedding.cache import EmbeddingCache
|
||||
|
||||
|
||||
# ========================================
|
||||
# Cache Tests
|
||||
# ========================================
|
||||
|
||||
def test_cache_init():
    """A freshly created in-memory cache starts out empty."""
    assert EmbeddingCache(":memory:").size() == 0
|
||||
|
||||
|
||||
def test_cache_set_get():
    """A stored embedding round-trips unchanged through set()/get()."""
    cache = EmbeddingCache(":memory:")
    vector = [0.1, 0.2, 0.3]

    cache.set("hash123", vector, "test-model")

    assert cache.get("hash123") == vector
|
||||
|
||||
|
||||
def test_cache_has():
    """has() reports presence for stored keys and absence for unknown ones."""
    cache = EmbeddingCache(":memory:")
    cache.set("hash123", [0.1, 0.2, 0.3], "test-model")

    assert cache.has("hash123") is True
    assert cache.has("nonexistent") is False
|
||||
|
||||
|
||||
def test_cache_delete():
    """delete() removes a previously stored entry."""
    cache = EmbeddingCache(":memory:")
    cache.set("hash123", [0.1, 0.2, 0.3], "test-model")
    assert cache.has("hash123") is True

    cache.delete("hash123")

    assert cache.has("hash123") is False
|
||||
|
||||
|
||||
def test_cache_clear():
    """clear() supports per-model and full wipes and reports deletion counts."""
    cache = EmbeddingCache(":memory:")
    for key, vec, model in (
        ("hash1", [0.1], "model1"),
        ("hash2", [0.2], "model2"),
        ("hash3", [0.3], "model1"),
    ):
        cache.set(key, vec, model)
    assert cache.size() == 3

    # Scoped clear removes only model1's two entries.
    removed = cache.clear(model="model1")
    assert removed == 2
    assert cache.size() == 1

    # Unscoped clear wipes the remainder.
    removed = cache.clear()
    assert removed == 1
    assert cache.size() == 0
|
||||
|
||||
|
||||
def test_cache_stats():
    """stats() reports the total entry count and a per-model breakdown."""
    cache = EmbeddingCache(":memory:")
    for key, vec, model in (
        ("hash1", [0.1], "model1"),
        ("hash2", [0.2], "model2"),
        ("hash3", [0.3], "model1"),
    ):
        cache.set(key, vec, model)

    summary = cache.stats()

    assert summary["total"] == 3
    assert summary["by_model"]["model1"] == 2
    assert summary["by_model"]["model2"] == 1
|
||||
|
||||
|
||||
def test_cache_context_manager():
    """The cache works as a context manager against a real database file."""
    with tempfile.NamedTemporaryFile(delete=False) as handle:
        db_path = handle.name

    try:
        with EmbeddingCache(db_path) as cache:
            cache.set("hash1", [0.1], "model1")
            assert cache.size() == 1

        # The backing database file was created on disk.
        assert Path(db_path).exists()
    finally:
        Path(db_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
# ========================================
|
||||
# Generator Tests
|
||||
# ========================================
|
||||
|
||||
def test_generator_init():
    """Constructing a generator with defaults succeeds."""
    assert EmbeddingGenerator() is not None
|
||||
|
||||
|
||||
def test_generator_list_models():
    """list_models() returns a non-empty catalog with the expected fields."""
    catalog = EmbeddingGenerator().list_models()

    assert len(catalog) > 0
    for entry in catalog:
        assert "name" in entry
        assert "provider" in entry
        assert "dimensions" in entry
|
||||
|
||||
|
||||
def test_generator_get_model_info():
    """Metadata for a known OpenAI model is reported correctly."""
    info = EmbeddingGenerator().get_model_info("text-embedding-3-small")

    assert info["provider"] == "openai"
    assert info["dimensions"] == 1536
    assert info["max_tokens"] == 8191
|
||||
|
||||
|
||||
def test_generator_get_model_info_invalid():
    """An unrecognized model name raises ValueError."""
    with pytest.raises(ValueError, match="Unknown model"):
        EmbeddingGenerator().get_model_info("nonexistent-model")
|
||||
|
||||
|
||||
def test_generator_compute_hash():
    """Hashes are deterministic and sensitive to both text and model."""
    base = EmbeddingGenerator.compute_hash("text1", "model1")

    # Same inputs -> same hash.
    assert base == EmbeddingGenerator.compute_hash("text1", "model1")
    # Changing the text changes the hash.
    assert base != EmbeddingGenerator.compute_hash("text2", "model1")
    # Changing the model changes the hash.
    assert base != EmbeddingGenerator.compute_hash("text1", "model2")
|
||||
|
||||
|
||||
@patch('skill_seekers.embedding.generator.SENTENCE_TRANSFORMERS_AVAILABLE', False)
def test_generator_sentence_transformers_not_available():
    """Local models fail with a clear ImportError when the package is absent."""
    with pytest.raises(ImportError, match="sentence-transformers is required"):
        EmbeddingGenerator().generate("test", model="all-MiniLM-L6-v2")
|
||||
|
||||
|
||||
@patch('skill_seekers.embedding.generator.OPENAI_AVAILABLE', False)
def test_generator_openai_not_available():
    """OpenAI models fail with a clear ImportError when the SDK is absent."""
    with pytest.raises(ImportError, match="OpenAI is required"):
        EmbeddingGenerator().generate("test", model="text-embedding-3-small")
|
||||
|
||||
|
||||
@patch('skill_seekers.embedding.generator.VOYAGE_AVAILABLE', False)
def test_generator_voyage_not_available():
    """Voyage models fail with a clear ImportError when the SDK is absent."""
    with pytest.raises(ImportError, match="voyageai is required"):
        EmbeddingGenerator().generate("test", model="voyage-3")
|
||||
|
||||
|
||||
def test_generator_voyage_model_info():
    """Metadata for voyage-3 is reported correctly."""
    info = EmbeddingGenerator().get_model_info("voyage-3")

    assert info["provider"] == "voyage"
    assert info["dimensions"] == 1024
    assert info["max_tokens"] == 32000
|
||||
|
||||
|
||||
def test_generator_voyage_large_2_model_info():
    """Metadata for voyage-large-2 is reported correctly."""
    info = EmbeddingGenerator().get_model_info("voyage-large-2")

    assert info["provider"] == "voyage"
    assert info["dimensions"] == 1536
    assert info["cost_per_million"] == 0.12
|
||||
|
||||
|
||||
# ========================================
|
||||
# Model Tests
|
||||
# ========================================
|
||||
|
||||
def test_embedding_request():
    """EmbeddingRequest stores text, model, and the normalize flag."""
    req = EmbeddingRequest(
        text="Hello world",
        model="text-embedding-3-small",
        normalize=True,
    )

    assert req.text == "Hello world"
    assert req.model == "text-embedding-3-small"
    assert req.normalize is True
|
||||
|
||||
|
||||
def test_batch_embedding_request():
    """BatchEmbeddingRequest carries the text list and batch size."""
    req = BatchEmbeddingRequest(
        texts=["text1", "text2", "text3"],
        model="text-embedding-3-small",
        batch_size=32,
    )

    assert len(req.texts) == 3
    assert req.batch_size == 32
|
||||
|
||||
|
||||
def test_embedding_response():
    """EmbeddingResponse exposes the vector, its size, and the cache flag."""
    resp = EmbeddingResponse(
        embedding=[0.1, 0.2, 0.3],
        model="test-model",
        dimensions=3,
        cached=False,
    )

    assert len(resp.embedding) == 3
    assert resp.dimensions == 3
    assert resp.cached is False
|
||||
|
||||
|
||||
def test_batch_embedding_response():
    """BatchEmbeddingResponse exposes all vectors plus hit accounting."""
    resp = BatchEmbeddingResponse(
        embeddings=[[0.1, 0.2], [0.3, 0.4]],
        model="test-model",
        dimensions=2,
        count=2,
        cached_count=1,
    )

    assert len(resp.embeddings) == 2
    assert resp.count == 2
    assert resp.cached_count == 1
|
||||
|
||||
|
||||
def test_health_response():
    """HealthResponse reports status, model list, and cache details."""
    resp = HealthResponse(
        status="ok",
        version="1.0.0",
        models=["model1", "model2"],
        cache_enabled=True,
        cache_size=100,
    )

    assert resp.status == "ok"
    assert len(resp.models) == 2
    assert resp.cache_size == 100
|
||||
|
||||
|
||||
def test_model_info():
    """ModelInfo stores the catalog fields it was built from."""
    entry = ModelInfo(
        name="test-model",
        provider="openai",
        dimensions=1536,
        max_tokens=8191,
        cost_per_million=0.02,
    )

    assert entry.name == "test-model"
    assert entry.provider == "openai"
    assert entry.cost_per_million == 0.02
|
||||
|
||||
|
||||
# ========================================
|
||||
# Integration Tests
|
||||
# ========================================
|
||||
|
||||
def test_cache_batch_operations():
    """get_batch() returns per-key embeddings and hit flags, None on misses."""
    cache = EmbeddingCache(":memory:")
    for key, vec in (
        ("hash1", [0.1, 0.2]),
        ("hash2", [0.3, 0.4]),
        ("hash3", [0.5, 0.6]),
    ):
        cache.set(key, vec, "model1")

    # One unknown key mixed in among three stored ones.
    embeddings, hit_flags = cache.get_batch(["hash1", "hash2", "hash999", "hash3"])

    assert len(embeddings) == 4
    assert embeddings[0] == [0.1, 0.2]
    assert embeddings[1] == [0.3, 0.4]
    assert embeddings[2] is None  # cache miss
    assert embeddings[3] == [0.5, 0.6]

    assert hit_flags == [True, True, False, True]
|
||||
|
||||
|
||||
def test_generator_normalize():
    """_normalize rescales a vector to unit Euclidean length."""
    import numpy as np

    # [3, 4] has norm 5 before normalization.
    normalized = EmbeddingGenerator._normalize([3.0, 4.0])

    assert abs(np.linalg.norm(normalized) - 1.0) < 1e-6
|
||||
|
||||
|
||||
def test_cache_persistence():
    """Entries written via one cache handle survive reopening the file."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as handle:
        db_path = handle.name

    try:
        # Write through one handle, then close it.
        writer = EmbeddingCache(db_path)
        writer.set("hash1", [0.1, 0.2, 0.3], "model1")
        writer.close()

        # A brand-new handle sees the persisted entry.
        reader = EmbeddingCache(db_path)
        assert reader.get("hash1") == [0.1, 0.2, 0.3]
        reader.close()
    finally:
        Path(db_path).unlink(missing_ok=True)
|
||||
259
tests/test_mcp_vector_dbs.py
Normal file
259
tests/test_mcp_vector_dbs.py
Normal file
@@ -0,0 +1,259 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for MCP vector database tools.
|
||||
|
||||
Validates the 4 new vector database export tools:
|
||||
- export_to_weaviate
|
||||
- export_to_chroma
|
||||
- export_to_faiss
|
||||
- export_to_qdrant
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import tempfile
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
from skill_seekers.mcp.tools.vector_db_tools import (
|
||||
export_to_weaviate_impl,
|
||||
export_to_chroma_impl,
|
||||
export_to_faiss_impl,
|
||||
export_to_qdrant_impl,
|
||||
)
|
||||
|
||||
|
||||
def run_async(coro):
    """Drive *coro* to completion on a fresh event loop and return its result."""
    return asyncio.run(coro)
|
||||
|
||||
|
||||
@pytest.fixture
def test_skill_dir():
    """Yield a throwaway skill directory with a SKILL.md and two references."""
    with tempfile.TemporaryDirectory() as tmpdir:
        skill_dir = Path(tmpdir) / "test_skill"
        skill_dir.mkdir()

        # Top-level skill document.
        (skill_dir / "SKILL.md").write_text(
            "# Test Skill\n\n"
            "This is a test skill for vector database export.\n\n"
            "## Getting Started\n\n"
            "Quick start guide content.\n"
        )

        # Two reference documents under references/.
        references = skill_dir / "references"
        references.mkdir()
        for fname, body in (
            ("api.md", "# API Reference\n\nAPI documentation."),
            ("examples.md", "# Examples\n\nCode examples."),
        ):
            (references / fname).write_text(body)

        yield skill_dir
|
||||
|
||||
|
||||
def test_export_to_weaviate(test_skill_dir):
    """The Weaviate export tool reports success and includes usage hints."""
    payload = {
        "skill_dir": str(test_skill_dir),
        "output_dir": str(test_skill_dir.parent),
    }

    result = run_async(export_to_weaviate_impl(payload))

    # Exactly one TextContent-style item comes back.
    assert isinstance(result, list)
    assert len(result) == 1
    assert hasattr(result[0], "text")

    report = result[0].text
    assert "✅ Weaviate Export Complete!" in report
    assert "test_skill-weaviate.json" in report
    assert "weaviate.Client" in report  # usage instructions are embedded
|
||||
|
||||
|
||||
def test_export_to_chroma(test_skill_dir):
    """The Chroma export tool reports success and includes usage hints."""
    payload = {
        "skill_dir": str(test_skill_dir),
        "output_dir": str(test_skill_dir.parent),
    }

    result = run_async(export_to_chroma_impl(payload))

    # Exactly one TextContent-style item comes back.
    assert isinstance(result, list)
    assert len(result) == 1
    assert hasattr(result[0], "text")

    report = result[0].text
    assert "✅ Chroma Export Complete!" in report
    assert "test_skill-chroma.json" in report
    assert "chromadb" in report  # usage instructions are embedded
|
||||
|
||||
|
||||
def test_export_to_faiss(test_skill_dir):
    """The FAISS export tool reports success and includes usage hints."""
    payload = {
        "skill_dir": str(test_skill_dir),
        "output_dir": str(test_skill_dir.parent),
    }

    result = run_async(export_to_faiss_impl(payload))

    # Exactly one TextContent-style item comes back.
    assert isinstance(result, list)
    assert len(result) == 1
    assert hasattr(result[0], "text")

    report = result[0].text
    assert "✅ FAISS Export Complete!" in report
    assert "test_skill-faiss.json" in report
    assert "import faiss" in report  # usage instructions are embedded
|
||||
|
||||
|
||||
def test_export_to_qdrant(test_skill_dir):
    """The Qdrant export tool reports success and includes usage hints."""
    payload = {
        "skill_dir": str(test_skill_dir),
        "output_dir": str(test_skill_dir.parent),
    }

    result = run_async(export_to_qdrant_impl(payload))

    # Exactly one TextContent-style item comes back.
    assert isinstance(result, list)
    assert len(result) == 1
    assert hasattr(result[0], "text")

    report = result[0].text
    assert "✅ Qdrant Export Complete!" in report
    assert "test_skill-qdrant.json" in report
    assert "QdrantClient" in report  # usage instructions are embedded
|
||||
|
||||
|
||||
def test_export_with_default_output_dir(test_skill_dir):
    """Omitting output_dir falls back to the skill's parent directory."""
    result = run_async(export_to_weaviate_impl({"skill_dir": str(test_skill_dir)}))

    assert isinstance(result, list)
    assert len(result) == 1

    report = result[0].text
    assert "✅" in report
    assert "test_skill-weaviate.json" in report
|
||||
|
||||
|
||||
def test_export_missing_skill_dir():
    """A nonexistent skill directory yields an error report, not an exception."""
    result = run_async(export_to_weaviate_impl({"skill_dir": "/nonexistent/path"}))

    assert isinstance(result, list)
    assert len(result) == 1

    report = result[0].text
    assert "❌ Error" in report
    assert "not found" in report
|
||||
|
||||
|
||||
def test_all_exports_create_files(test_skill_dir):
    """Every export tool succeeds and writes valid JSON next to the skill."""
    output_dir = test_skill_dir.parent
    exporters = [
        ("weaviate", export_to_weaviate_impl),
        ("chroma", export_to_chroma_impl),
        ("faiss", export_to_faiss_impl),
        ("qdrant", export_to_qdrant_impl),
    ]

    for target, exporter in exporters:
        result = run_async(exporter({
            "skill_dir": str(test_skill_dir),
            "output_dir": str(output_dir),
        }))

        # The tool reported success.
        assert isinstance(result, list)
        assert "✅" in result[0].text

        # The output file exists and parses as a JSON object.
        exported = output_dir / f"test_skill-{target}.json"
        assert exported.exists(), f"{target} export file not created"
        with open(exported) as fh:
            assert isinstance(json.load(fh), dict)
|
||||
|
||||
|
||||
def test_export_output_includes_instructions():
    """Each export's report contains its target-specific usage instructions."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Minimal skill: a SKILL.md plus one reference document.
        skill_dir = Path(tmpdir) / "test_skill"
        skill_dir.mkdir()
        (skill_dir / "SKILL.md").write_text("# Test")

        refs_dir = skill_dir / "references"
        refs_dir.mkdir()
        (refs_dir / "guide.md").write_text("# Guide")

        args = {"skill_dir": str(skill_dir)}

        # Exporter -> phrases its report must contain.
        expectations = [
            (export_to_weaviate_impl,
             ("Next Steps:", "Upload to Weaviate:",
              "Query with hybrid search:", "Resources:")),
            (export_to_chroma_impl,
             ("Next Steps:", "Load into Chroma:", "Query the collection:")),
            (export_to_faiss_impl,
             ("Next Steps:", "Build FAISS index:", "Search:")),
            (export_to_qdrant_impl,
             ("Next Steps:", "Upload to Qdrant:", "Search with filters:")),
        ]

        for exporter, phrases in expectations:
            report = run_async(exporter(args))[0].text
            for phrase in phrases:
                assert phrase in report
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pytest's exit status so a direct run of this file fails
    # (non-zero exit code) when any test fails; previously the return value
    # of pytest.main() was discarded and the script always exited 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
Reference in New Issue
Block a user