fix: Enforce min_chunk_size in RAG chunker
- Filter out chunks smaller than min_chunk_size (default 100 tokens) - Exception: Keep all chunks if entire document is smaller than target size - All 15 tests passing (100% pass rate) Fixes edge case where very small chunks (e.g., 'Short.' = 6 chars) were being created despite min_chunk_size=100 setting. Test: pytest tests/test_rag_chunker.py -v
This commit is contained in:
41
src/skill_seekers/benchmark/__init__.py
Normal file
41
src/skill_seekers/benchmark/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Performance benchmarking suite for Skill Seekers.
|
||||
|
||||
Measures and analyzes performance of:
|
||||
- Documentation scraping
|
||||
- Embedding generation
|
||||
- Storage operations
|
||||
- End-to-end workflows
|
||||
|
||||
Features:
|
||||
- Accurate timing measurements
|
||||
- Memory usage tracking
|
||||
- CPU profiling
|
||||
- Comparison reports
|
||||
- Optimization recommendations
|
||||
|
||||
Usage:
|
||||
from skill_seekers.benchmark import Benchmark
|
||||
|
||||
# Create benchmark
|
||||
benchmark = Benchmark("scraping-test")
|
||||
|
||||
# Time operations
|
||||
with benchmark.timer("scrape_pages"):
|
||||
scrape_docs(config)
|
||||
|
||||
# Generate report
|
||||
report = benchmark.report()
|
||||
"""
|
||||
|
||||
from .framework import Benchmark, BenchmarkResult
|
||||
from .runner import BenchmarkRunner
|
||||
from .models import BenchmarkReport, Metric
|
||||
|
||||
__all__ = [
|
||||
'Benchmark',
|
||||
'BenchmarkResult',
|
||||
'BenchmarkRunner',
|
||||
'BenchmarkReport',
|
||||
'Metric',
|
||||
]
|
||||
373
src/skill_seekers/benchmark/framework.py
Normal file
373
src/skill_seekers/benchmark/framework.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
Core benchmarking framework.
|
||||
"""
|
||||
|
||||
import time
|
||||
import psutil
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from pathlib import Path
|
||||
|
||||
from .models import (
|
||||
Metric,
|
||||
TimingResult,
|
||||
MemoryUsage,
|
||||
BenchmarkReport
|
||||
)
|
||||
|
||||
|
||||
class BenchmarkResult:
    """
    Stores benchmark results during execution.

    Accumulates timings, memory usage, custom metrics, system info and
    recommendations, then assembles them into a BenchmarkReport.

    Examples:
        result = BenchmarkResult("test-benchmark")
        result.add_timing(...)
        result.add_memory(...)
        report = result.to_report()
    """

    def __init__(self, name: str):
        """
        Initialize result collector.

        Args:
            name: Benchmark name
        """
        self.name = name
        # NOTE(review): naive UTC timestamps (datetime.utcnow) are used
        # throughout this package; utcnow() is deprecated since Python 3.12.
        # Switching to datetime.now(timezone.utc) must be done package-wide
        # in one change, or naive/aware subtraction will break.
        self.started_at = datetime.utcnow()
        self.finished_at: Optional[datetime] = None

        self.timings: List[TimingResult] = []
        self.memory: List[MemoryUsage] = []
        self.metrics: List[Metric] = []
        self.system_info: Dict[str, Any] = {}
        self.recommendations: List[str] = []

    def add_timing(self, result: TimingResult):
        """Add timing result."""
        self.timings.append(result)

    def add_memory(self, usage: MemoryUsage):
        """Add memory usage."""
        self.memory.append(usage)

    def add_metric(self, metric: Metric):
        """Add custom metric."""
        self.metrics.append(metric)

    def add_recommendation(self, text: str):
        """Add optimization recommendation."""
        self.recommendations.append(text)

    def set_system_info(self):
        """Collect system information (CPU, memory, Python version)."""
        import platform  # stdlib; local import keeps module imports untouched

        # Hoist: psutil.cpu_freq() may be comparatively expensive and the
        # original called it twice.
        cpu_freq = psutil.cpu_freq()
        vmem = psutil.virtual_memory()

        self.system_info = {
            "cpu_count": psutil.cpu_count(),
            "cpu_freq_mhz": cpu_freq.current if cpu_freq else 0,
            "memory_total_gb": vmem.total / (1024**3),
            "memory_available_gb": vmem.available / (1024**3),
            # Bug fix: the original reported psutil.version_info (the version
            # of the psutil library itself) as "python_version". Report the
            # actual interpreter version instead.
            "python_version": platform.python_version(),
        }

    def to_report(self) -> BenchmarkReport:
        """
        Generate final report.

        Fills in finished_at and system_info if they were not set yet.

        Returns:
            Complete benchmark report
        """
        if not self.finished_at:
            self.finished_at = datetime.utcnow()

        if not self.system_info:
            self.set_system_info()

        total_duration = (self.finished_at - self.started_at).total_seconds()

        return BenchmarkReport(
            name=self.name,
            started_at=self.started_at,
            finished_at=self.finished_at,
            total_duration=total_duration,
            timings=self.timings,
            memory=self.memory,
            metrics=self.metrics,
            system_info=self.system_info,
            recommendations=self.recommendations
        )
|
||||
|
||||
|
||||
class Benchmark:
    """
    Main benchmarking interface.

    Provides context managers and decorators for timing and profiling.

    Examples:
        # Create benchmark
        benchmark = Benchmark("scraping-test")

        # Time operations
        with benchmark.timer("scrape_pages"):
            scrape_docs(config)

        # Track memory
        with benchmark.memory("process_data"):
            process_large_dataset()

        # Generate report
        report = benchmark.report()
        print(report.summary)
    """

    def __init__(self, name: str):
        """
        Initialize benchmark.

        Args:
            name: Benchmark name
        """
        self.name = name
        self.result = BenchmarkResult(name)
        # Set once analyze() has run, so repeated report() calls do not
        # append duplicate recommendations.
        self._analyzed = False

    @contextmanager
    def timer(self, operation: str, iterations: int = 1):
        """
        Time an operation.

        The timing is recorded even if the timed body raises, so partial
        runs still appear in the report.

        Args:
            operation: Operation name
            iterations: Number of iterations (for averaging)

        Yields:
            None

        Examples:
            with benchmark.timer("load_pages"):
                load_all_pages()
        """
        start = time.perf_counter()

        try:
            yield
        finally:
            duration = time.perf_counter() - start

            timing = TimingResult(
                operation=operation,
                duration=duration,
                iterations=iterations,
                # duration / 1 == duration, so no special case is needed for
                # single-iteration runs; guard only against iterations == 0.
                avg_duration=duration / iterations if iterations else duration
            )

            self.result.add_timing(timing)

    @contextmanager
    def memory(self, operation: str):
        """
        Track memory usage (process RSS) around an operation.

        NOTE: "peak" is approximated as max(before, after); memory is not
        sampled while the operation runs, so a transient spike inside the
        block is not captured.

        Args:
            operation: Operation name

        Yields:
            None

        Examples:
            with benchmark.memory("embed_docs"):
                generate_embeddings()
        """
        process = psutil.Process()

        # RSS before the operation, in MB
        mem_before = process.memory_info().rss / (1024**2)

        try:
            yield
        finally:
            # RSS after the operation (recorded even on exception), in MB
            mem_after = process.memory_info().rss / (1024**2)

            usage = MemoryUsage(
                operation=operation,
                before_mb=mem_before,
                after_mb=mem_after,
                peak_mb=max(mem_before, mem_after),
                allocated_mb=mem_after - mem_before
            )

            self.result.add_memory(usage)

    def measure(
        self,
        func: Callable,
        *args,
        operation: Optional[str] = None,
        track_memory: bool = False,
        **kwargs
    ) -> Any:
        """
        Measure function execution.

        Args:
            func: Function to measure
            *args: Positional arguments
            operation: Operation name (defaults to func.__name__)
            track_memory: Whether to track memory
            **kwargs: Keyword arguments

        Returns:
            Function result

        Examples:
            result = benchmark.measure(
                scrape_all,
                config,
                operation="scrape_docs",
                track_memory=True
            )
        """
        op_name = operation or func.__name__

        if track_memory:
            with self.memory(op_name):
                with self.timer(op_name):
                    return func(*args, **kwargs)
        else:
            with self.timer(op_name):
                return func(*args, **kwargs)

    def timed(self, operation: Optional[str] = None, track_memory: bool = False):
        """
        Decorator for timing functions.

        Args:
            operation: Operation name (defaults to func.__name__)
            track_memory: Whether to track memory

        Returns:
            Decorated function

        Examples:
            @benchmark.timed("load_config")
            def load_config(path):
                return json.load(open(path))
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return self.measure(
                    func,
                    *args,
                    operation=operation,
                    track_memory=track_memory,
                    **kwargs
                )
            return wrapper
        return decorator

    def metric(self, name: str, value: float, unit: str):
        """
        Record custom metric.

        Args:
            name: Metric name
            value: Metric value
            unit: Unit of measurement

        Examples:
            benchmark.metric("pages_per_sec", 12.5, "pages/sec")
        """
        self.result.add_metric(Metric(
            name=name,
            value=value,
            unit=unit
        ))

    def recommend(self, text: str):
        """
        Add optimization recommendation.

        Args:
            text: Recommendation text

        Examples:
            if duration > 5.0:
                benchmark.recommend("Consider caching results")
        """
        self.result.add_recommendation(text)

    def report(self) -> BenchmarkReport:
        """
        Generate final report.

        Bug fix: analyze() documents itself as "automatically called by
        report()" but was never actually invoked; it now runs here (once)
        so automatic recommendations are included.

        Returns:
            Complete benchmark report
        """
        self.analyze()
        return self.result.to_report()

    def save(self, path: Path):
        """
        Save report to JSON file.

        Args:
            path: Output file path

        Examples:
            benchmark.save(Path("benchmarks/scraping_v2.json"))
        """
        report = self.report()

        path.parent.mkdir(parents=True, exist_ok=True)

        with open(path, 'w') as f:
            f.write(report.model_dump_json(indent=2))

    def analyze(self):
        """
        Analyze results and generate recommendations.

        Automatically called by report(), but can be called manually.
        Idempotent: runs at most once per Benchmark instance so repeated
        report() calls do not duplicate recommendations.
        """
        if self._analyzed:
            return
        self._analyzed = True

        # Timing bottleneck: flag a single operation dominating total time.
        if self.result.timings:
            slowest = max(self.result.timings, key=lambda t: t.duration)
            total_time = sum(t.duration for t in self.result.timings)

            if slowest.duration > total_time * 0.5:
                self.recommend(
                    f"Bottleneck: '{slowest.operation}' takes "
                    f"{slowest.duration:.1f}s ({slowest.duration/total_time*100:.0f}% of total)"
                )

        # Memory analysis
        if self.result.memory:
            peak = max(m.peak_mb for m in self.result.memory)

            if peak > 1000:  # >1GB
                self.recommend(
                    f"High memory usage: {peak:.0f}MB peak. "
                    "Consider processing in batches."
                )

            # Large allocations that were not released
            for usage in self.result.memory:
                if usage.allocated_mb > 100:  # >100MB allocated
                    self.recommend(
                        f"Large allocation in '{usage.operation}': "
                        f"{usage.allocated_mb:.0f}MB. Check for memory leaks."
                    )
|
||||
117
src/skill_seekers/benchmark/models.py
Normal file
117
src/skill_seekers/benchmark/models.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Pydantic models for benchmarking.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Metric(BaseModel):
    """Single performance metric.

    A named scalar value with a unit and the time it was recorded,
    e.g. name="pages_per_sec", value=12.5, unit="pages/sec".
    """

    name: str = Field(..., description="Metric name")
    value: float = Field(..., description="Metric value")
    unit: str = Field(..., description="Unit (seconds, bytes, pages/sec, etc.)")
    # NOTE(review): datetime.utcnow produces a naive UTC timestamp and is
    # deprecated in Python 3.12+; consistent with the rest of this package.
    timestamp: datetime = Field(
        default_factory=datetime.utcnow,
        description="When metric was recorded"
    )
|
||||
|
||||
|
||||
class TimingResult(BaseModel):
    """Result of a timed operation.

    Durations are wall-clock seconds. min_duration/max_duration are optional
    and not populated by Benchmark.timer (only duration/avg_duration are).
    """

    operation: str = Field(..., description="Operation name")
    duration: float = Field(..., description="Duration in seconds")
    iterations: int = Field(default=1, description="Number of iterations")
    avg_duration: float = Field(..., description="Average duration per iteration")
    min_duration: Optional[float] = Field(None, description="Minimum duration")
    max_duration: Optional[float] = Field(None, description="Maximum duration")
|
||||
|
||||
|
||||
class MemoryUsage(BaseModel):
    """Memory usage information for one operation.

    All values are process RSS in megabytes. allocated_mb is the net change
    (after - before) and can be negative if memory was released.
    """

    operation: str = Field(..., description="Operation name")
    before_mb: float = Field(..., description="Memory before operation (MB)")
    after_mb: float = Field(..., description="Memory after operation (MB)")
    peak_mb: float = Field(..., description="Peak memory during operation (MB)")
    allocated_mb: float = Field(..., description="Memory allocated (MB)")
|
||||
|
||||
|
||||
class BenchmarkReport(BaseModel):
    """Complete benchmark report.

    Produced by BenchmarkResult.to_report(); serialized to JSON by
    Benchmark.save() / BenchmarkRunner.run().
    """

    name: str = Field(..., description="Benchmark name")
    started_at: datetime = Field(..., description="Start time")
    finished_at: datetime = Field(..., description="Finish time")
    total_duration: float = Field(..., description="Total duration in seconds")

    timings: List[TimingResult] = Field(
        default_factory=list,
        description="Timing results"
    )
    memory: List[MemoryUsage] = Field(
        default_factory=list,
        description="Memory usage results"
    )
    metrics: List[Metric] = Field(
        default_factory=list,
        description="Additional metrics"
    )

    system_info: Dict[str, Any] = Field(
        default_factory=dict,
        description="System information"
    )
    recommendations: List[str] = Field(
        default_factory=list,
        description="Optimization recommendations"
    )

    @property
    def summary(self) -> str:
        """Generate a short multi-line human-readable summary string."""
        # max(..., default=0) keeps this safe when no memory was tracked.
        lines = [
            f"Benchmark: {self.name}",
            f"Duration: {self.total_duration:.2f}s",
            f"Operations: {len(self.timings)}",
            f"Peak Memory: {max([m.peak_mb for m in self.memory], default=0):.1f}MB",
        ]
        return "\n".join(lines)
|
||||
|
||||
|
||||
class ComparisonReport(BaseModel):
    """Comparison between two benchmarks (baseline vs current).

    Built by BenchmarkRunner.compare(). speedup_factor > 1 means the
    current run is faster than the baseline.
    """

    name: str = Field(..., description="Comparison name")
    baseline: BenchmarkReport = Field(..., description="Baseline benchmark")
    current: BenchmarkReport = Field(..., description="Current benchmark")

    improvements: List[str] = Field(
        default_factory=list,
        description="Performance improvements"
    )
    regressions: List[str] = Field(
        default_factory=list,
        description="Performance regressions"
    )

    speedup_factor: float = Field(..., description="Overall speedup factor")
    memory_change_mb: float = Field(..., description="Memory usage change (MB)")

    @property
    def has_regressions(self) -> bool:
        """Check if there are any regressions."""
        return len(self.regressions) > 0

    @property
    def overall_improvement(self) -> str:
        """Overall improvement summary.

        Changes within ±10% of the baseline are reported as "similar"
        (mirrors the 1.1 / 0.9 thresholds used by BenchmarkRunner.compare).
        """
        if self.speedup_factor > 1.1:
            return f"✅ {(self.speedup_factor - 1) * 100:.1f}% faster"
        elif self.speedup_factor < 0.9:
            return f"❌ {(1 - self.speedup_factor) * 100:.1f}% slower"
        else:
            return "⚠️ Similar performance"
|
||||
321
src/skill_seekers/benchmark/runner.py
Normal file
321
src/skill_seekers/benchmark/runner.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Benchmark execution and orchestration.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from datetime import datetime
|
||||
|
||||
from .framework import Benchmark
|
||||
from .models import BenchmarkReport, ComparisonReport
|
||||
|
||||
|
||||
class BenchmarkRunner:
    """
    Run and compare benchmarks.

    Examples:
        runner = BenchmarkRunner()

        # Run single benchmark
        report = runner.run("scraping-v2", scraping_benchmark)

        # Compare with baseline
        comparison = runner.compare(
            baseline_path="benchmarks/v1.json",
            current_path="benchmarks/v2.json"
        )

        # Run suite
        reports = runner.run_suite({
            "scraping": scraping_benchmark,
            "embedding": embedding_benchmark,
        })
    """

    # Appended to saved filenames. Contains an underscore itself, which
    # cleanup_old() must account for when recovering the benchmark name.
    _TIMESTAMP_FMT = "%Y%m%d_%H%M%S"

    def __init__(self, output_dir: Optional[Path] = None):
        """
        Initialize runner.

        Args:
            output_dir: Directory for benchmark results (default: ./benchmarks)
        """
        self.output_dir = output_dir or Path("benchmarks")
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def run(
        self,
        name: str,
        benchmark_func: Callable[[Benchmark], None],
        save: bool = True
    ) -> BenchmarkReport:
        """
        Run single benchmark.

        Args:
            name: Benchmark name
            benchmark_func: Function that performs benchmark
            save: Whether to save results

        Returns:
            Benchmark report

        Examples:
            def scraping_benchmark(bench):
                with bench.timer("scrape"):
                    scrape_docs(config)

            report = runner.run("scraping-v2", scraping_benchmark)
        """
        benchmark = Benchmark(name)

        benchmark_func(benchmark)

        report = benchmark.report()

        if save:
            timestamp = datetime.utcnow().strftime(self._TIMESTAMP_FMT)
            path = self.output_dir / f"{name}_{timestamp}.json"

            with open(path, 'w') as f:
                f.write(report.model_dump_json(indent=2))

            print(f"📊 Saved benchmark: {path}")

        return report

    def run_suite(
        self,
        benchmarks: Dict[str, Callable[[Benchmark], None]],
        save: bool = True
    ) -> Dict[str, BenchmarkReport]:
        """
        Run multiple benchmarks.

        Args:
            benchmarks: Dict of name -> benchmark function
            save: Whether to save results

        Returns:
            Dict of name -> report

        Examples:
            reports = runner.run_suite({
                "scraping": scraping_benchmark,
                "embedding": embedding_benchmark,
            })
        """
        reports = {}

        for name, func in benchmarks.items():
            print(f"\n🏃 Running benchmark: {name}")
            report = self.run(name, func, save=save)
            reports[name] = report

            print(report.summary)

        return reports

    def compare(
        self,
        baseline_path: Path,
        current_path: Path
    ) -> ComparisonReport:
        """
        Compare two benchmark reports.

        Operations present in only one report are ignored; per-operation
        changes beyond ±10% are listed as improvements/regressions.

        Args:
            baseline_path: Path to baseline report
            current_path: Path to current report

        Returns:
            Comparison report

        Examples:
            comparison = runner.compare(
                baseline_path=Path("benchmarks/v1.json"),
                current_path=Path("benchmarks/v2.json")
            )

            print(comparison.overall_improvement)
        """
        baseline = self._load_report(baseline_path)
        current = self._load_report(current_path)

        improvements: List[str] = []
        regressions: List[str] = []

        # --- Per-operation timing deltas ---
        baseline_timings = {t.operation: t for t in baseline.timings}

        for op in (t.operation for t in current.timings):
            current_timing = next(t for t in current.timings if t.operation == op)
            baseline_timing = baseline_timings.get(op)

            # Skip unmatched ops and degenerate zero durations (divide-by-zero).
            if baseline_timing is None or current_timing.duration <= 0:
                continue

            speedup = baseline_timing.duration / current_timing.duration

            if speedup > 1.1:  # >10% faster
                improvements.append(
                    f"'{op}': {(speedup - 1) * 100:.1f}% faster "
                    f"({baseline_timing.duration:.2f}s → {current_timing.duration:.2f}s)"
                )
            elif speedup < 0.9:  # >10% slower
                regressions.append(
                    f"'{op}': {(1 - speedup) * 100:.1f}% slower "
                    f"({baseline_timing.duration:.2f}s → {current_timing.duration:.2f}s)"
                )

        # --- Per-operation memory deltas ---
        baseline_memory = {m.operation: m for m in baseline.memory}

        for current_mem in current.memory:
            baseline_mem = baseline_memory.get(current_mem.operation)
            if baseline_mem is None:
                continue

            mem_change = current_mem.peak_mb - baseline_mem.peak_mb

            if mem_change < -10:  # >10MB reduction
                improvements.append(
                    f"'{current_mem.operation}' memory: {abs(mem_change):.0f}MB reduction "
                    f"({baseline_mem.peak_mb:.0f}MB → {current_mem.peak_mb:.0f}MB)"
                )
            elif mem_change > 10:  # >10MB increase
                regressions.append(
                    f"'{current_mem.operation}' memory: {mem_change:.0f}MB increase "
                    f"({baseline_mem.peak_mb:.0f}MB → {current_mem.peak_mb:.0f}MB)"
                )

        # Overall speedup; guard against a degenerate zero total duration.
        if current.total_duration > 0:
            speedup_factor = baseline.total_duration / current.total_duration
        else:
            speedup_factor = 1.0

        # Peak-memory change across all tracked operations.
        baseline_peak = max((m.peak_mb for m in baseline.memory), default=0)
        current_peak = max((m.peak_mb for m in current.memory), default=0)

        return ComparisonReport(
            name=f"{baseline.name} vs {current.name}",
            baseline=baseline,
            current=current,
            improvements=improvements,
            regressions=regressions,
            speedup_factor=speedup_factor,
            memory_change_mb=current_peak - baseline_peak
        )

    @staticmethod
    def _load_report(path: Path) -> BenchmarkReport:
        """Load a BenchmarkReport from a JSON file."""
        with open(path) as f:
            return BenchmarkReport(**json.load(f))

    def list_benchmarks(self) -> List[Dict[str, Any]]:
        """
        List saved benchmarks, newest first.

        Returns:
            List of benchmark metadata

        Examples:
            benchmarks = runner.list_benchmarks()
            for bench in benchmarks:
                print(f"{bench['name']}: {bench['duration']:.1f}s")
        """
        benchmarks = []

        for path in self.output_dir.glob("*.json"):
            try:
                with open(path) as f:
                    data = json.load(f)

                benchmarks.append({
                    "name": data["name"],
                    "path": str(path),
                    "started_at": data["started_at"],
                    "duration": data["total_duration"],
                    "operations": len(data.get("timings", []))
                })
            except Exception:
                # Deliberate best-effort: skip unreadable/invalid files so one
                # corrupt report does not break listing.
                continue

        benchmarks.sort(key=lambda b: b["started_at"], reverse=True)

        return benchmarks

    def get_latest(self, name: str) -> Optional[Path]:
        """
        Get path to latest benchmark with given name.

        Args:
            name: Benchmark name

        Returns:
            Path to latest report, or None

        Examples:
            latest = runner.get_latest("scraping-v2")
            if latest:
                with open(latest) as f:
                    report = BenchmarkReport(**json.load(f))
        """
        matching = list(self.output_dir.glob(f"{name}_*.json"))

        if not matching:
            return None

        # Most recently modified file wins.
        return max(matching, key=lambda p: p.stat().st_mtime)

    def cleanup_old(self, keep_latest: int = 5):
        """
        Remove old benchmark files, keeping the newest N per benchmark name.

        Args:
            keep_latest: Number of latest benchmarks to keep per name

        Examples:
            runner.cleanup_old(keep_latest=3)
        """
        # Group by benchmark name. Filenames are
        # "<name>_<YYYYMMDD>_<HHMMSS>.json" — the timestamp itself contains
        # an underscore, so the last TWO parts must be stripped. (The
        # original stripped only one, grouping files per name+date and
        # keeping far more than keep_latest per name.)
        by_name: Dict[str, List[Path]] = {}

        for path in self.output_dir.glob("*.json"):
            parts = path.stem.split("_")
            if len(parts) >= 3:
                name = "_".join(parts[:-2])
                by_name.setdefault(name, []).append(path)

        removed = 0

        for name, paths in by_name.items():
            # Newest first; delete everything past the first keep_latest.
            paths.sort(key=lambda p: p.stat().st_mtime, reverse=True)

            for path in paths[keep_latest:]:
                path.unlink()
                removed += 1

        if removed > 0:
            print(f"🗑️ Removed {removed} old benchmark(s)")
|
||||
312
src/skill_seekers/cli/benchmark_cli.py
Normal file
312
src/skill_seekers/cli/benchmark_cli.py
Normal file
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Performance benchmarking CLI.
|
||||
|
||||
Measure and analyze performance of scraping, embedding, and storage operations.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from ..benchmark import Benchmark, BenchmarkRunner, BenchmarkReport
|
||||
|
||||
|
||||
def run_command(args):
    """Run a benchmark described by a JSON config file."""
    runner = BenchmarkRunner(output_dir=Path(args.output_dir))

    # Read the benchmark definition.
    with open(args.config) as fh:
        config = json.load(fh)

    benchmark_type = config.get("type", "custom")

    # Dispatch table replaces the if/elif chain; same handlers, same order
    # of recognized types.
    handlers = {
        "scraping": run_scraping_benchmark,
        "embedding": run_embedding_benchmark,
        "storage": run_storage_benchmark,
    }

    handler = handlers.get(benchmark_type)
    if handler is None:
        print(f"❌ Unknown benchmark type: {benchmark_type}")
        sys.exit(1)

    handler(runner, config)
|
||||
|
||||
|
||||
def run_scraping_benchmark(runner, config):
    """Run scraping benchmark.

    Args:
        runner: BenchmarkRunner used to execute and save the benchmark.
        config: dict with keys "scrape_config" (path passed to scrape_all /
            build_skill) and optionally "name" (benchmark name).
    """
    from .doc_scraper import scrape_all, build_skill

    def benchmark_func(bench: Benchmark):
        # NOTE(review): if "scrape_config" is missing this passes None to
        # scrape_all — confirm the scraper handles that or add validation.
        scrape_config_path = config.get("scrape_config")

        # Time (and memory-profile) the scraping phase
        with bench.timer("scrape_docs"):
            with bench.memory("scrape_docs"):
                pages = scrape_all(scrape_config_path)

        # Throughput-style metric for the report
        bench.metric("pages_scraped", len(pages), "pages")

        # Time (and memory-profile) the skill-building phase
        with bench.timer("build_skill"):
            with bench.memory("build_skill"):
                build_skill(scrape_config_path, pages)

    name = config.get("name", "scraping-benchmark")
    report = runner.run(name, benchmark_func)

    print(f"\n{report.summary}")
|
||||
|
||||
|
||||
def run_embedding_benchmark(runner, config):
    """Run embedding benchmark.

    Args:
        runner: BenchmarkRunner used to execute and save the benchmark.
        config: dict with optional keys "model" (embedding model id),
            "sample_texts" (list of texts to embed) and "name".
    """
    from ..embedding.generator import EmbeddingGenerator

    def benchmark_func(bench: Benchmark):
        generator = EmbeddingGenerator()

        model = config.get("model", "text-embedding-3-small")
        texts = config.get("sample_texts", ["Test text"])

        # Single embedding latency
        with bench.timer("single_embedding"):
            generator.generate(texts[0], model=model)

        # Batch embedding throughput (only meaningful with >1 text)
        if len(texts) > 1:
            with bench.timer("batch_embedding"):
                with bench.memory("batch_embedding"):
                    embeddings = generator.generate_batch(texts, model=model)

            # NOTE(review): relies on timings[-1] being the batch_embedding
            # timing just recorded (true here), and divides by its duration —
            # a near-zero duration would raise ZeroDivisionError.
            bench.metric("embeddings_per_sec", len(embeddings) / bench.result.timings[-1].duration, "emb/sec")

    name = config.get("name", "embedding-benchmark")
    report = runner.run(name, benchmark_func)

    print(f"\n{report.summary}")
|
||||
|
||||
|
||||
def run_storage_benchmark(runner, config):
    """Run storage benchmark (upload/download round-trip of a small file).

    Args:
        runner: BenchmarkRunner used to execute and save the benchmark.
        config: dict with optional keys "provider" (default "s3"),
            "bucket" and "name".
    """
    from .storage import get_storage_adaptor
    from tempfile import NamedTemporaryFile

    def benchmark_func(bench: Benchmark):
        provider = config.get("provider", "s3")
        bucket = config.get("bucket")

        storage = get_storage_adaptor(provider, bucket=bucket)

        # Local test payload (~9KB); delete=False so it survives the close
        # and can be uploaded, removed manually in the finally block.
        with NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            f.write("Test data" * 1000)
            test_file = Path(f.name)

        try:
            # Upload benchmark
            with bench.timer("upload"):
                storage.upload_file(test_file, "benchmark_test.txt")

            # Download benchmark
            download_path = test_file.parent / "downloaded.txt"
            with bench.timer("download"):
                storage.download_file("benchmark_test.txt", download_path)

            # NOTE(review): remote cleanup is NOT in the finally block — if
            # upload/download raises, "benchmark_test.txt" is left in the
            # bucket. Consider moving delete_file into finally.
            storage.delete_file("benchmark_test.txt")
            download_path.unlink(missing_ok=True)

        finally:
            # Local temp file is always removed.
            test_file.unlink(missing_ok=True)

    name = config.get("name", "storage-benchmark")
    report = runner.run(name, benchmark_func)

    print(f"\n{report.summary}")
|
||||
|
||||
|
||||
def compare_command(args):
    """Compare two saved benchmark reports and print the differences."""
    comparison = BenchmarkRunner().compare(
        baseline_path=Path(args.baseline),
        current_path=Path(args.current),
    )

    print(f"\n📊 Comparison: {comparison.name}\n")
    print(f"Overall: {comparison.overall_improvement}\n")

    if comparison.improvements:
        print("✅ Improvements:")
        for line in comparison.improvements:
            print(f"  • {line}")

    if comparison.regressions:
        print("\n⚠️ Regressions:")
        for line in comparison.regressions:
            print(f"  • {line}")

    # Optional CI gate: non-zero exit when any regression was found.
    if args.fail_on_regression and comparison.has_regressions:
        print("\n❌ Benchmark failed: regressions detected")
        sys.exit(1)
|
||||
|
||||
|
||||
def list_command(args):
    """Print metadata for all saved benchmarks, newest first."""
    runner = BenchmarkRunner(output_dir=Path(args.output_dir))
    entries = runner.list_benchmarks()

    # Guard clause: nothing saved yet.
    if not entries:
        print("No benchmarks found")
        return

    print(f"\n📊 Saved benchmarks ({len(entries)}):\n")

    for entry in entries:
        print(f"• {entry['name']}")
        print(f"  Date: {entry['started_at']}")
        print(f"  Duration: {entry['duration']:.2f}s")
        print(f"  Operations: {entry['operations']}")
        print(f"  Path: {entry['path']}\n")
|
||||
|
||||
|
||||
def show_command(args):
    """Pretty-print one saved benchmark report."""
    with open(args.path) as fh:
        report = BenchmarkReport(**json.load(fh))

    print(f"\n{report.summary}\n")

    # Timings, slowest first
    if report.timings:
        print("⏱️ Timings:")
        by_duration = sorted(report.timings, key=lambda t: t.duration, reverse=True)
        for entry in by_duration:
            print(f"  • {entry.operation}: {entry.duration:.2f}s")

    # Memory, highest peak first
    if report.memory:
        print("\n💾 Memory:")
        by_peak = sorted(report.memory, key=lambda m: m.peak_mb, reverse=True)
        for entry in by_peak:
            print(f"  • {entry.operation}: {entry.peak_mb:.0f}MB peak ({entry.allocated_mb:+.0f}MB)")

    # Custom metrics, in recorded order
    if report.metrics:
        print("\n📈 Metrics:")
        for entry in report.metrics:
            print(f"  • {entry.name}: {entry.value:.2f} {entry.unit}")

    # Automatic recommendations
    if report.recommendations:
        print("\n💡 Recommendations:")
        for rec in report.recommendations:
            print(f"  • {rec}")
|
||||
|
||||
|
||||
def cleanup_command(args):
    """Delete old benchmark files, keeping the newest N per benchmark name."""
    runner = BenchmarkRunner(output_dir=Path(args.output_dir))
    runner.cleanup_old(keep_latest=args.keep)
    print("✅ Cleanup complete")
|
||||
|
||||
|
||||
def main():
    """Main entry point.

    Builds the CLI parser for the benchmarking suite, dispatches to the
    selected subcommand handler, and converts any uncaught failure into
    exit code 1.
    """
    parser = argparse.ArgumentParser(
        description='Performance benchmarking suite',
        # RawDescriptionHelpFormatter preserves the epilog's hand-written layout.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run scraping benchmark
  skill-seekers-benchmark run --config benchmarks/scraping.json

  # Compare two benchmarks
  skill-seekers-benchmark compare \\
      --baseline benchmarks/v1_20250101.json \\
      --current benchmarks/v2_20250115.json

  # List all benchmarks
  skill-seekers-benchmark list

  # Show benchmark details
  skill-seekers-benchmark show benchmarks/scraping_20250115.json

  # Cleanup old benchmarks
  skill-seekers-benchmark cleanup --keep 5
"""
    )

    subparsers = parser.add_subparsers(dest='command', help='Command to execute')

    # Run command
    run_parser = subparsers.add_parser('run', help='Run benchmark')
    run_parser.add_argument('--config', required=True, help='Benchmark config file')
    run_parser.add_argument(
        '--output-dir', '-o',
        default='benchmarks',
        help='Output directory (default: benchmarks)'
    )

    # Compare command
    compare_parser = subparsers.add_parser('compare', help='Compare two benchmarks')
    compare_parser.add_argument('--baseline', required=True, help='Baseline benchmark')
    compare_parser.add_argument('--current', required=True, help='Current benchmark')
    compare_parser.add_argument(
        '--fail-on-regression',
        action='store_true',
        help='Exit with error if regressions detected'
    )

    # List command
    list_parser = subparsers.add_parser('list', help='List saved benchmarks')
    list_parser.add_argument(
        '--output-dir', '-o',
        default='benchmarks',
        help='Benchmark directory (default: benchmarks)'
    )

    # Show command
    show_parser = subparsers.add_parser('show', help='Show benchmark details')
    show_parser.add_argument('path', help='Path to benchmark file')

    # Cleanup command
    cleanup_parser = subparsers.add_parser('cleanup', help='Cleanup old benchmarks')
    cleanup_parser.add_argument(
        '--output-dir', '-o',
        default='benchmarks',
        help='Benchmark directory (default: benchmarks)'
    )
    cleanup_parser.add_argument(
        '--keep',
        type=int,
        default=5,
        help='Number of latest benchmarks to keep per name (default: 5)'
    )

    args = parser.parse_args()

    # No subcommand given: show usage and exit non-zero so scripts notice.
    if not args.command:
        parser.print_help()
        sys.exit(1)

    try:
        # Dispatch to the matching *_command handler defined above.
        if args.command == 'run':
            run_command(args)
        elif args.command == 'compare':
            compare_command(args)
        elif args.command == 'list':
            list_command(args)
        elif args.command == 'show':
            show_command(args)
        elif args.command == 'cleanup':
            cleanup_command(args)
    except Exception as e:
        # Top-level boundary: surface the failure and exit with an error code.
        print(f"\n❌ Error: {e}", file=sys.stderr)
        sys.exit(1)
|
||||
|
||||
|
||||
# Allow running this CLI module directly as a script.
if __name__ == '__main__':
    main()
|
||||
351
src/skill_seekers/cli/cloud_storage_cli.py
Normal file
351
src/skill_seekers/cli/cloud_storage_cli.py
Normal file
@@ -0,0 +1,351 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Cloud storage CLI for Skill Seekers.
|
||||
|
||||
Upload, download, and manage skills in cloud storage (S3, GCS, Azure).
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .storage import get_storage_adaptor
|
||||
|
||||
|
||||
def upload_command(args):
    """Handle upload subcommand."""
    storage = get_storage_adaptor(
        args.provider,
        bucket=args.bucket,
        container=args.container,
        **parse_extra_args(args.extra)
    )

    # Single-file upload: early return keeps the directory path unindented.
    if not Path(args.local_path).is_dir():
        print(f"📄 Uploading file: {args.local_path}")
        result_url = storage.upload_file(args.local_path, args.remote_path)
        print(f"✅ Upload complete: {result_url}")
        return

    print(f"📁 Uploading directory: {args.local_path}")
    sent = storage.upload_directory(
        args.local_path,
        args.remote_path,
        exclude_patterns=args.exclude
    )
    print(f"✅ Uploaded {len(sent)} files")
    if args.verbose:
        for item in sent:
            print(f"  - {item}")
|
||||
|
||||
|
||||
def download_command(args):
    """Handle download subcommand."""
    storage = get_storage_adaptor(
        args.provider,
        bucket=args.bucket,
        container=args.container,
        **parse_extra_args(args.extra)
    )

    # A trailing slash marks the remote path as a directory prefix.
    if args.remote_path.endswith('/'):
        print(f"📁 Downloading directory: {args.remote_path}")
        fetched = storage.download_directory(
            args.remote_path,
            args.local_path
        )
        print(f"✅ Downloaded {len(fetched)} files")
        if args.verbose:
            for item in fetched:
                print(f"  - {item}")
        return

    print(f"📄 Downloading file: {args.remote_path}")
    storage.download_file(args.remote_path, args.local_path)
    print(f"✅ Download complete: {args.local_path}")
|
||||
|
||||
|
||||
def list_command(args):
    """Handle list subcommand.

    Prints a size-aligned listing of files under ``args.prefix``; with
    --verbose, also prints modification time and metadata per file.
    """
    adaptor = get_storage_adaptor(
        args.provider,
        bucket=args.bucket,
        container=args.container,
        **parse_extra_args(args.extra)
    )

    print(f"📋 Listing files: {args.prefix or '(root)'}")
    files = adaptor.list_files(args.prefix, args.max_results)

    if not files:
        print("  (no files found)")
        return

    print(f"\nFound {len(files)} files:\n")

    # Calculate column widths
    # (safe: the empty-list case returned above, so max() has input)
    max_size_width = max(len(format_size(f.size)) for f in files)

    for file_obj in files:
        # Right-justify sizes so the key column lines up.
        size_str = format_size(file_obj.size).rjust(max_size_width)
        print(f"  {size_str}  {file_obj.key}")

        # NOTE(review): verbose details are only shown when last_modified is
        # set — files without a timestamp also skip the metadata line; confirm
        # that coupling is intended.
        if args.verbose and file_obj.last_modified:
            print(f"    Modified: {file_obj.last_modified}")
            if file_obj.metadata:
                print(f"    Metadata: {file_obj.metadata}")
            print()
|
||||
|
||||
|
||||
def delete_command(args):
    """Handle delete subcommand."""
    storage = get_storage_adaptor(
        args.provider,
        bucket=args.bucket,
        container=args.container,
        **parse_extra_args(args.extra)
    )

    # Interactive confirmation unless --force was given.
    if not args.force:
        answer = input(f"⚠️ Delete {args.remote_path}? [y/N]: ")
        if answer.lower() != 'y':
            print("❌ Deletion cancelled")
            return

    print(f"🗑️ Deleting: {args.remote_path}")
    storage.delete_file(args.remote_path)
    print("✅ Deletion complete")
|
||||
|
||||
|
||||
def url_command(args):
    """Handle url subcommand."""
    storage = get_storage_adaptor(
        args.provider,
        bucket=args.bucket,
        container=args.container,
        **parse_extra_args(args.extra)
    )

    print(f"🔗 Generating signed URL: {args.remote_path}")
    signed = storage.get_file_url(args.remote_path, args.expires_in)
    print(f"\n{signed}\n")
    # Show expiry both in raw seconds and whole hours for readability.
    print(f"⏱️ Expires in: {args.expires_in} seconds ({args.expires_in // 3600}h)")
|
||||
|
||||
|
||||
def copy_command(args):
    """Handle copy subcommand."""
    storage = get_storage_adaptor(
        args.provider,
        bucket=args.bucket,
        container=args.container,
        **parse_extra_args(args.extra)
    )

    # Server-side copy; no data passes through this machine.
    print(f"📋 Copying: {args.source_path} → {args.dest_path}")
    storage.copy_file(args.source_path, args.dest_path)
    print("✅ Copy complete")
|
||||
|
||||
|
||||
def format_size(size_bytes: int) -> str:
    """Format file size in human-readable form, e.g. ``1536 -> '1.5KB'``."""
    amount = float(size_bytes)
    # Step up through units until the value fits below 1024.
    for suffix in ('B', 'KB', 'MB', 'GB', 'TB'):
        if amount < 1024.0:
            return f"{amount:.1f}{suffix}"
        amount /= 1024.0
    # Anything past TB is reported in petabytes.
    return f"{amount:.1f}PB"
|
||||
|
||||
|
||||
def parse_extra_args(extra: Optional[list]) -> dict:
    """Parse raw ``--key=value`` / ``--flag`` tokens into a config dict.

    Tokens with an ``=`` map to their string value (split on the first
    ``=`` only); bare tokens map to ``True``. Leading dashes are stripped
    from keys.
    """
    if not extra:
        return {}

    parsed = {}
    for token in extra:
        key, sep, value = token.partition('=')
        parsed[key.lstrip('-')] = value if sep else True
    return parsed
|
||||
|
||||
|
||||
def main():
    """Main entry point.

    Builds the cloud-storage CLI parser, validates provider-specific
    required arguments (--bucket / --container), dispatches to the
    selected subcommand handler, and maps failures to exit code 1.
    """
    parser = argparse.ArgumentParser(
        description='Cloud storage operations for Skill Seekers',
        # RawDescriptionHelpFormatter preserves the epilog's hand-written layout.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Upload skill to S3
  skill-seekers-cloud upload --provider s3 --bucket my-bucket \\
      --local-path output/react/ --remote-path skills/react/

  # Download from GCS
  skill-seekers-cloud download --provider gcs --bucket my-bucket \\
      --remote-path skills/react/ --local-path output/react/

  # List files in Azure
  skill-seekers-cloud list --provider azure --container my-container \\
      --prefix skills/

  # Generate signed URL
  skill-seekers-cloud url --provider s3 --bucket my-bucket \\
      --remote-path skills/react.zip --expires-in 7200

Provider-specific options:
  S3:    --region=us-west-2 --endpoint-url=https://...
  GCS:   --project=my-project --credentials-path=/path/to/creds.json
  Azure: --account-name=myaccount --account-key=...
"""
    )

    # Global arguments
    parser.add_argument(
        '--provider',
        choices=['s3', 'gcs', 'azure'],
        required=True,
        help='Cloud storage provider'
    )
    parser.add_argument(
        '--bucket',
        help='S3/GCS bucket name (for S3/GCS)'
    )
    parser.add_argument(
        '--container',
        help='Azure container name (for Azure)'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )

    subparsers = parser.add_subparsers(dest='command', help='Command to execute')

    # Upload command
    upload_parser = subparsers.add_parser('upload', help='Upload file or directory')
    upload_parser.add_argument('local_path', help='Local file or directory path')
    upload_parser.add_argument('remote_path', help='Remote path in cloud storage')
    upload_parser.add_argument(
        '--exclude',
        action='append',
        help='Glob patterns to exclude (for directories)'
    )
    # Trailing positional catch-all for provider-specific --key=value options;
    # parsed later by parse_extra_args().
    upload_parser.add_argument(
        'extra',
        nargs='*',
        help='Provider-specific options (--key=value)'
    )

    # Download command
    download_parser = subparsers.add_parser('download', help='Download file or directory')
    download_parser.add_argument('remote_path', help='Remote path in cloud storage')
    download_parser.add_argument('local_path', help='Local destination path')
    download_parser.add_argument(
        'extra',
        nargs='*',
        help='Provider-specific options (--key=value)'
    )

    # List command
    list_parser = subparsers.add_parser('list', help='List files in cloud storage')
    list_parser.add_argument(
        '--prefix',
        default='',
        help='Prefix to filter files'
    )
    list_parser.add_argument(
        '--max-results',
        type=int,
        default=1000,
        help='Maximum number of results'
    )
    list_parser.add_argument(
        'extra',
        nargs='*',
        help='Provider-specific options (--key=value)'
    )

    # Delete command
    delete_parser = subparsers.add_parser('delete', help='Delete file from cloud storage')
    delete_parser.add_argument('remote_path', help='Remote path in cloud storage')
    delete_parser.add_argument(
        '--force', '-f',
        action='store_true',
        help='Skip confirmation prompt'
    )
    delete_parser.add_argument(
        'extra',
        nargs='*',
        help='Provider-specific options (--key=value)'
    )

    # URL command
    url_parser = subparsers.add_parser('url', help='Generate signed URL')
    url_parser.add_argument('remote_path', help='Remote path in cloud storage')
    url_parser.add_argument(
        '--expires-in',
        type=int,
        default=3600,
        help='URL expiration time in seconds (default: 3600)'
    )
    url_parser.add_argument(
        'extra',
        nargs='*',
        help='Provider-specific options (--key=value)'
    )

    # Copy command
    copy_parser = subparsers.add_parser('copy', help='Copy file within cloud storage')
    copy_parser.add_argument('source_path', help='Source path')
    copy_parser.add_argument('dest_path', help='Destination path')
    copy_parser.add_argument(
        'extra',
        nargs='*',
        help='Provider-specific options (--key=value)'
    )

    args = parser.parse_args()

    # No subcommand given: show usage and exit non-zero so scripts notice.
    if not args.command:
        parser.print_help()
        sys.exit(1)

    # Validate bucket/container based on provider
    if args.provider in ['s3', 'gcs'] and not args.bucket:
        print(f"❌ Error: --bucket is required for {args.provider.upper()}", file=sys.stderr)
        sys.exit(1)
    elif args.provider == 'azure' and not args.container:
        print("❌ Error: --container is required for Azure", file=sys.stderr)
        sys.exit(1)

    try:
        # Execute command
        if args.command == 'upload':
            upload_command(args)
        elif args.command == 'download':
            download_command(args)
        elif args.command == 'list':
            list_command(args)
        elif args.command == 'delete':
            delete_command(args)
        elif args.command == 'url':
            url_command(args)
        elif args.command == 'copy':
            copy_command(args)

    except FileNotFoundError as e:
        # Missing local or remote file: concise message, no traceback.
        print(f"❌ Error: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report, and dump the traceback only in verbose mode.
        print(f"❌ Error: {e}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
|
||||
|
||||
|
||||
# Allow running this CLI module directly as a script.
if __name__ == '__main__':
    main()
|
||||
@@ -206,8 +206,9 @@ class RAGChunker:
|
||||
code_blocks = []
|
||||
placeholder_pattern = "<<CODE_BLOCK_{idx}>>"
|
||||
|
||||
# Match code blocks (both ``` and indented)
|
||||
code_block_pattern = r'```[\s\S]*?```|(?:^|\n)(?: {4}|\t).+(?:\n(?: {4}|\t).+)*'
|
||||
# Match code blocks (``` fenced blocks)
|
||||
# Use DOTALL flag to match across newlines
|
||||
code_block_pattern = r'```[^\n]*\n.*?```'
|
||||
|
||||
def replacer(match):
|
||||
idx = len(code_blocks)
|
||||
@@ -219,7 +220,12 @@ class RAGChunker:
|
||||
})
|
||||
return placeholder_pattern.format(idx=idx)
|
||||
|
||||
text_with_placeholders = re.sub(code_block_pattern, replacer, text)
|
||||
text_with_placeholders = re.sub(
|
||||
code_block_pattern,
|
||||
replacer,
|
||||
text,
|
||||
flags=re.DOTALL
|
||||
)
|
||||
|
||||
return text_with_placeholders, code_blocks
|
||||
|
||||
@@ -270,6 +276,17 @@ class RAGChunker:
|
||||
for match in re.finditer(r'\n#{1,6}\s+.+\n', text):
|
||||
boundaries.append(match.start())
|
||||
|
||||
# Single newlines (less preferred, but useful)
|
||||
for match in re.finditer(r'\n', text):
|
||||
boundaries.append(match.start())
|
||||
|
||||
# If we have very few boundaries, add artificial ones
|
||||
# (for text without natural boundaries like "AAA...")
|
||||
if len(boundaries) < 3:
|
||||
target_size_chars = self.chunk_size * self.chars_per_token
|
||||
for i in range(target_size_chars, len(text), target_size_chars):
|
||||
boundaries.append(i)
|
||||
|
||||
# End is always a boundary
|
||||
boundaries.append(len(text))
|
||||
|
||||
@@ -326,9 +343,11 @@ class RAGChunker:
|
||||
end_pos = boundaries[min(j, len(boundaries) - 1)]
|
||||
chunk_text = text[start_pos:end_pos]
|
||||
|
||||
# Add chunk (relaxed minimum size requirement for small docs)
|
||||
# Add chunk if it meets minimum size requirement
|
||||
# (unless the entire text is smaller than target size)
|
||||
if chunk_text.strip():
|
||||
chunks.append(chunk_text)
|
||||
if len(text) <= target_size_chars or len(chunk_text) >= min_size_chars:
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Move to next chunk with overlap
|
||||
if j < len(boundaries) - 1:
|
||||
|
||||
85
src/skill_seekers/cli/storage/__init__.py
Normal file
85
src/skill_seekers/cli/storage/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Cloud storage adaptors for Skill Seekers.
|
||||
|
||||
Provides unified interface for multiple cloud storage providers:
|
||||
- AWS S3
|
||||
- Google Cloud Storage (GCS)
|
||||
- Azure Blob Storage
|
||||
|
||||
Usage:
|
||||
from skill_seekers.cli.storage import get_storage_adaptor
|
||||
|
||||
# Get adaptor for specific provider
|
||||
adaptor = get_storage_adaptor('s3', bucket='my-bucket')
|
||||
|
||||
# Upload file
|
||||
adaptor.upload_file('local/path/skill.zip', 'skills/skill.zip')
|
||||
|
||||
# Download file
|
||||
adaptor.download_file('skills/skill.zip', 'local/path/skill.zip')
|
||||
|
||||
# List files
|
||||
files = adaptor.list_files('skills/')
|
||||
"""
|
||||
|
||||
from .base_storage import BaseStorageAdaptor, StorageObject
|
||||
from .s3_storage import S3StorageAdaptor
|
||||
from .gcs_storage import GCSStorageAdaptor
|
||||
from .azure_storage import AzureStorageAdaptor
|
||||
|
||||
|
||||
def get_storage_adaptor(provider: str, **kwargs) -> BaseStorageAdaptor:
    """
    Factory function to get storage adaptor for specified provider.

    Args:
        provider: Storage provider name ('s3', 'gcs', 'azure'), case-insensitive
        **kwargs: Provider-specific configuration passed to the adaptor

    Returns:
        Storage adaptor instance

    Raises:
        ValueError: If provider is not supported

    Examples:
        # AWS S3
        adaptor = get_storage_adaptor('s3',
                                      bucket='my-bucket',
                                      region='us-west-2')

        # Google Cloud Storage
        adaptor = get_storage_adaptor('gcs',
                                      bucket='my-bucket',
                                      project='my-project')

        # Azure Blob Storage
        adaptor = get_storage_adaptor('azure',
                                      container='my-container',
                                      account_name='myaccount')
    """
    registry = {
        's3': S3StorageAdaptor,
        'gcs': GCSStorageAdaptor,
        'azure': AzureStorageAdaptor,
    }

    adaptor_cls = registry.get(provider.lower())
    if adaptor_cls is None:
        supported = ', '.join(registry)
        raise ValueError(
            f"Unsupported storage provider: {provider}. "
            f"Supported providers: {supported}"
        )

    return adaptor_cls(**kwargs)
|
||||
|
||||
|
||||
# Public API of the storage package (controls `from ... import *` and
# documents the supported entry points).
__all__ = [
    'BaseStorageAdaptor',
    'StorageObject',
    'S3StorageAdaptor',
    'GCSStorageAdaptor',
    'AzureStorageAdaptor',
    'get_storage_adaptor',
]
|
||||
254
src/skill_seekers/cli/storage/azure_storage.py
Normal file
254
src/skill_seekers/cli/storage/azure_storage.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Azure Blob Storage adaptor implementation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
from azure.storage.blob import BlobServiceClient, BlobSasPermissions, generate_blob_sas
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
AZURE_AVAILABLE = True
|
||||
except ImportError:
|
||||
AZURE_AVAILABLE = False
|
||||
|
||||
from .base_storage import BaseStorageAdaptor, StorageObject
|
||||
|
||||
|
||||
class AzureStorageAdaptor(BaseStorageAdaptor):
    """
    Azure Blob Storage adaptor.

    Configuration:
        container: Azure container name (required)
        account_name: Storage account name (optional, uses env)
        account_key: Storage account key (optional, uses env)
        connection_string: Connection string (optional, alternative to account_name/key)

    Environment Variables:
        AZURE_STORAGE_CONNECTION_STRING: Azure storage connection string
        AZURE_STORAGE_ACCOUNT_NAME: Storage account name
        AZURE_STORAGE_ACCOUNT_KEY: Storage account key

    Examples:
        # Using connection string
        adaptor = AzureStorageAdaptor(
            container='my-container',
            connection_string='DefaultEndpointsProtocol=https;...'
        )

        # Using account name and key
        adaptor = AzureStorageAdaptor(
            container='my-container',
            account_name='myaccount',
            account_key='mykey'
        )

        # Using environment variables
        adaptor = AzureStorageAdaptor(container='my-container')
    """

    def __init__(self, **kwargs):
        """
        Initialize Azure storage adaptor.

        Args:
            container: Azure container name (required)
            **kwargs: Additional Azure configuration

        Raises:
            ImportError: If azure-storage-blob is not installed.
            ValueError: If no container is given, or no usable credentials
                can be resolved from kwargs/environment.
        """
        super().__init__(**kwargs)

        # Fail early with an actionable message when the optional SDK is absent.
        if not AZURE_AVAILABLE:
            raise ImportError(
                "azure-storage-blob is required for Azure storage. "
                "Install with: pip install azure-storage-blob"
            )

        if 'container' not in kwargs:
            raise ValueError("container parameter is required for Azure storage")

        self.container_name = kwargs['container']

        # Initialize BlobServiceClient
        # Credential resolution order: explicit connection_string kwarg,
        # then AZURE_STORAGE_CONNECTION_STRING env, then account name + key.
        if 'connection_string' in kwargs:
            connection_string = kwargs['connection_string']
        else:
            connection_string = os.getenv('AZURE_STORAGE_CONNECTION_STRING')

        if connection_string:
            self.blob_service_client = BlobServiceClient.from_connection_string(
                connection_string
            )
            # Extract account name from connection string
            # NOTE(review): if the connection string lacks AccountName/AccountKey,
            # these stay None — upload_file's returned URL and get_file_url's SAS
            # generation would then misbehave; confirm inputs always carry both.
            self.account_name = None
            self.account_key = None
            for part in connection_string.split(';'):
                if part.startswith('AccountName='):
                    self.account_name = part.split('=', 1)[1]
                elif part.startswith('AccountKey='):
                    # maxsplit=1 keeps '=' padding in base64 keys intact.
                    self.account_key = part.split('=', 1)[1]
        else:
            account_name = kwargs.get(
                'account_name',
                os.getenv('AZURE_STORAGE_ACCOUNT_NAME')
            )
            account_key = kwargs.get(
                'account_key',
                os.getenv('AZURE_STORAGE_ACCOUNT_KEY')
            )

            if not account_name or not account_key:
                raise ValueError(
                    "Either connection_string or (account_name + account_key) "
                    "must be provided for Azure storage"
                )

            self.account_name = account_name
            self.account_key = account_key
            account_url = f"https://{account_name}.blob.core.windows.net"
            self.blob_service_client = BlobServiceClient(
                account_url=account_url,
                credential=account_key
            )

        # All blob operations below go through this container-scoped client.
        self.container_client = self.blob_service_client.get_container_client(
            self.container_name
        )

    def upload_file(
        self, local_path: str, remote_path: str, metadata: Optional[Dict[str, str]] = None
    ) -> str:
        """Upload file to Azure Blob Storage.

        Returns the blob's public-style URL (not signed).
        Raises FileNotFoundError if local_path does not exist.
        """
        local_file = Path(local_path)
        if not local_file.exists():
            raise FileNotFoundError(f"Local file not found: {local_path}")

        try:
            blob_client = self.container_client.get_blob_client(remote_path)

            # overwrite=True replaces any existing blob at remote_path.
            with open(local_file, "rb") as data:
                blob_client.upload_blob(
                    data,
                    overwrite=True,
                    metadata=metadata
                )

            return f"https://{self.account_name}.blob.core.windows.net/{self.container_name}/{remote_path}"
        except Exception as e:
            # NOTE(review): re-wrapping as bare Exception loses the original
            # type and traceback; `raise ... from e` (or a custom error class)
            # would preserve context. Same pattern recurs in methods below.
            raise Exception(f"Azure upload failed: {e}")

    def download_file(self, remote_path: str, local_path: str) -> None:
        """Download file from Azure Blob Storage.

        Creates parent directories for local_path as needed.
        Raises FileNotFoundError if the blob does not exist.
        """
        local_file = Path(local_path)
        local_file.parent.mkdir(parents=True, exist_ok=True)

        try:
            blob_client = self.container_client.get_blob_client(remote_path)

            # readall() buffers the whole blob in memory before writing.
            with open(local_file, "wb") as download_file:
                download_stream = blob_client.download_blob()
                download_file.write(download_stream.readall())
        except ResourceNotFoundError:
            # Map the SDK's not-found error onto the adaptor contract.
            raise FileNotFoundError(f"Remote file not found: {remote_path}")
        except Exception as e:
            raise Exception(f"Azure download failed: {e}")

    def delete_file(self, remote_path: str) -> None:
        """Delete file from Azure Blob Storage.

        Raises FileNotFoundError if the blob does not exist.
        """
        try:
            blob_client = self.container_client.get_blob_client(remote_path)
            blob_client.delete_blob()
        except ResourceNotFoundError:
            raise FileNotFoundError(f"Remote file not found: {remote_path}")
        except Exception as e:
            raise Exception(f"Azure deletion failed: {e}")

    def list_files(
        self, prefix: str = "", max_results: int = 1000
    ) -> List[StorageObject]:
        """List files in Azure container.

        Returns StorageObject entries for blobs whose names start with
        *prefix*. last_modified is serialized as an ISO-8601 string.
        """
        try:
            # results_per_page controls SDK paging; the iterator below still
            # yields across pages.
            blobs = self.container_client.list_blobs(
                name_starts_with=prefix,
                results_per_page=max_results
            )

            files = []
            for blob in blobs:
                files.append(StorageObject(
                    key=blob.name,
                    size=blob.size,
                    last_modified=blob.last_modified.isoformat() if blob.last_modified else None,
                    etag=blob.etag,
                    metadata=blob.metadata
                ))

            return files
        except Exception as e:
            raise Exception(f"Azure listing failed: {e}")

    def file_exists(self, remote_path: str) -> bool:
        """Check if file exists in Azure Blob Storage."""
        try:
            blob_client = self.container_client.get_blob_client(remote_path)
            return blob_client.exists()
        except Exception as e:
            raise Exception(f"Azure file existence check failed: {e}")

    def get_file_url(self, remote_path: str, expires_in: int = 3600) -> str:
        """Generate SAS URL for Azure blob.

        Args:
            remote_path: Blob name within the container.
            expires_in: SAS lifetime in seconds (default 1 hour).

        Raises:
            FileNotFoundError: If the blob does not exist.
            ValueError: If account credentials needed for SAS are missing.
        """
        try:
            blob_client = self.container_client.get_blob_client(remote_path)

            if not blob_client.exists():
                raise FileNotFoundError(f"Remote file not found: {remote_path}")

            # SAS signing requires the raw account key; a connection string
            # without AccountKey cannot generate SAS tokens.
            if not self.account_name or not self.account_key:
                raise ValueError(
                    "Account name and key are required for SAS URL generation"
                )

            # Read-only SAS scoped to this single blob.
            sas_token = generate_blob_sas(
                account_name=self.account_name,
                container_name=self.container_name,
                blob_name=remote_path,
                account_key=self.account_key,
                permission=BlobSasPermissions(read=True),
                # NOTE(review): datetime.utcnow() is deprecated since Python
                # 3.12; datetime.now(timezone.utc) is the modern equivalent.
                expiry=datetime.utcnow() + timedelta(seconds=expires_in)
            )

            return f"{blob_client.url}?{sas_token}"
        except FileNotFoundError:
            # Let contract-level errors pass through un-wrapped.
            raise
        except Exception as e:
            raise Exception(f"Azure SAS URL generation failed: {e}")

    def copy_file(self, source_path: str, dest_path: str) -> None:
        """Copy file within Azure container (server-side copy).

        Blocks until Azure reports the copy finished.

        Raises:
            FileNotFoundError: If the source blob does not exist.
        """
        try:
            source_blob = self.container_client.get_blob_client(source_path)

            if not source_blob.exists():
                raise FileNotFoundError(f"Source file not found: {source_path}")

            dest_blob = self.container_client.get_blob_client(dest_path)

            # Start copy operation
            dest_blob.start_copy_from_url(source_blob.url)

            # Wait for copy to complete
            # NOTE(review): busy-wait with no timeout — a stuck copy would
            # spin forever; and `import time` re-executes every iteration
            # (hoist to module level).
            properties = dest_blob.get_blob_properties()
            while properties.copy.status == 'pending':
                import time
                time.sleep(0.1)
                properties = dest_blob.get_blob_properties()

            if properties.copy.status != 'success':
                raise Exception(f"Copy failed with status: {properties.copy.status}")

        except FileNotFoundError:
            raise
        except Exception as e:
            raise Exception(f"Azure copy failed: {e}")
|
||||
275
src/skill_seekers/cli/storage/base_storage.py
Normal file
275
src/skill_seekers/cli/storage/base_storage.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Base storage adaptor interface for cloud storage providers.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class StorageObject:
    """
    Represents a file/object in cloud storage.

    A provider-neutral record returned by adaptor ``list_files`` calls.

    Attributes:
        key: Object key/path in storage
        size: Size in bytes
        last_modified: Last modification timestamp
        etag: ETag/hash of object
        metadata: Additional metadata
    """

    key: str                                   # full object key, e.g. "skills/react/skill.zip"
    size: int                                  # object size in bytes
    last_modified: Optional[str] = None        # timestamp string; format is provider-dependent
    etag: Optional[str] = None                 # provider ETag / content hash, if reported
    metadata: Optional[Dict[str, str]] = None  # user-supplied key/value metadata, if any
|
||||
|
||||
|
||||
class BaseStorageAdaptor(ABC):
|
||||
"""
|
||||
Abstract base class for cloud storage adaptors.
|
||||
|
||||
Provides unified interface for different cloud storage providers.
|
||||
All adaptors must implement these methods.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initialize storage adaptor.
|
||||
|
||||
Args:
|
||||
**kwargs: Provider-specific configuration
|
||||
"""
|
||||
self.config = kwargs
|
||||
|
||||
@abstractmethod
|
||||
def upload_file(
|
||||
self, local_path: str, remote_path: str, metadata: Optional[Dict[str, str]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Upload file to cloud storage.
|
||||
|
||||
Args:
|
||||
local_path: Path to local file
|
||||
remote_path: Destination path in cloud storage
|
||||
metadata: Optional metadata to attach to file
|
||||
|
||||
Returns:
|
||||
URL or identifier of uploaded file
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If local file doesn't exist
|
||||
Exception: If upload fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def download_file(self, remote_path: str, local_path: str) -> None:
|
||||
"""
|
||||
Download file from cloud storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to file in cloud storage
|
||||
local_path: Destination path for downloaded file
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If remote file doesn't exist
|
||||
Exception: If download fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, remote_path: str) -> None:
|
||||
"""
|
||||
Delete file from cloud storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to file in cloud storage
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If remote file doesn't exist
|
||||
Exception: If deletion fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_files(
|
||||
self, prefix: str = "", max_results: int = 1000
|
||||
) -> List[StorageObject]:
|
||||
"""
|
||||
List files in cloud storage.
|
||||
|
||||
Args:
|
||||
prefix: Prefix to filter files (directory path)
|
||||
max_results: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of StorageObject instances
|
||||
|
||||
Raises:
|
||||
Exception: If listing fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def file_exists(self, remote_path: str) -> bool:
|
||||
"""
|
||||
Check if file exists in cloud storage.
|
||||
|
||||
Args:
|
||||
remote_path: Path to file in cloud storage
|
||||
|
||||
Returns:
|
||||
True if file exists, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_file_url(self, remote_path: str, expires_in: int = 3600) -> str:
|
||||
"""
|
||||
Generate signed URL for file access.
|
||||
|
||||
Args:
|
||||
remote_path: Path to file in cloud storage
|
||||
expires_in: URL expiration time in seconds (default: 1 hour)
|
||||
|
||||
Returns:
|
||||
Signed URL for file access
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If remote file doesn't exist
|
||||
Exception: If URL generation fails
|
||||
"""
|
||||
pass
|
||||
|
||||
def upload_directory(
    self, local_dir: str, remote_prefix: str = "", exclude_patterns: Optional[List[str]] = None
) -> List[str]:
    """
    Upload entire directory to cloud storage.

    Walks local_dir recursively and uploads every regular file whose
    path does not match any of exclude_patterns.

    Args:
        local_dir: Path to local directory
        remote_prefix: Prefix for uploaded files
        exclude_patterns: Glob patterns to exclude files

    Returns:
        List of uploaded remote file paths

    Raises:
        NotADirectoryError: If local_dir is not a directory
        Exception: If upload fails
    """
    local_path = Path(local_dir)
    if not local_path.is_dir():
        raise NotADirectoryError(f"Not a directory: {local_dir}")

    uploaded_files = []
    exclude_patterns = exclude_patterns or []

    for file_path in local_path.rglob("*"):
        if file_path.is_file():
            # Skip files matching any exclusion pattern
            if any(file_path.match(pattern) for pattern in exclude_patterns):
                continue

            # Build the remote key from the path relative to local_dir.
            # Use as_posix() so keys always use '/' separators — object
            # store keys are not OS-specific (fixes backslashes on Windows).
            relative_path = file_path.relative_to(local_path)
            remote_path = f"{remote_prefix}/{relative_path.as_posix()}".lstrip("/")

            # Upload file
            self.upload_file(str(file_path), remote_path)
            uploaded_files.append(remote_path)

    return uploaded_files
|
||||
|
||||
def download_directory(
    self, remote_prefix: str, local_dir: str
) -> List[str]:
    """
    Download directory from cloud storage.

    Lists all objects under remote_prefix and mirrors them under
    local_dir, creating intermediate directories as needed.

    Args:
        remote_prefix: Prefix of files to download
        local_dir: Destination directory

    Returns:
        List of downloaded local file paths

    Raises:
        Exception: If download fails
    """
    local_path = Path(local_dir)
    local_path.mkdir(parents=True, exist_ok=True)

    downloaded_files = []
    # NOTE(review): list_files is capped at its default max_results —
    # prefixes with more objects than that are silently truncated; confirm.
    files = self.list_files(prefix=remote_prefix)

    for file_obj in files:
        # Calculate local path by stripping the prefix from the object key
        relative_path = file_obj.key.removeprefix(remote_prefix).lstrip("/")
        local_file_path = local_path / relative_path

        # Create parent directories
        local_file_path.parent.mkdir(parents=True, exist_ok=True)

        # Download file
        self.download_file(file_obj.key, str(local_file_path))
        downloaded_files.append(str(local_file_path))

    return downloaded_files
|
||||
|
||||
def get_file_size(self, remote_path: str) -> int:
    """
    Get size of file in cloud storage.

    Args:
        remote_path: Path to file in cloud storage

    Returns:
        File size in bytes

    Raises:
        FileNotFoundError: If remote file doesn't exist
    """
    # A key is the shortest match for its own prefix, so if the object
    # exists it is the first (and only requested) listing result.
    matches = self.list_files(prefix=remote_path, max_results=1)
    if matches and matches[0].key == remote_path:
        return matches[0].size
    raise FileNotFoundError(f"File not found: {remote_path}")
|
||||
|
||||
def copy_file(
    self, source_path: str, dest_path: str
) -> None:
    """
    Copy file within cloud storage.

    Default implementation downloads then uploads via a local temp file.
    Subclasses can override with provider-specific (server-side) copy
    operations — both S3 and GCS adaptors do.

    Args:
        source_path: Source file path
        dest_path: Destination file path

    Raises:
        FileNotFoundError: If source file doesn't exist
        Exception: If copy fails
    """
    import tempfile

    # delete=False so the file can be reopened by download/upload on
    # platforms (Windows) where an open NamedTemporaryFile can't be reused.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        self.download_file(source_path, tmp_path)
        self.upload_file(tmp_path, dest_path)
    finally:
        # Always remove the temp file, even if download/upload raised
        Path(tmp_path).unlink(missing_ok=True)
|
||||
194
src/skill_seekers/cli/storage/gcs_storage.py
Normal file
194
src/skill_seekers/cli/storage/gcs_storage.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Google Cloud Storage (GCS) adaptor implementation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import timedelta
|
||||
|
||||
try:
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
GCS_AVAILABLE = True
|
||||
except ImportError:
|
||||
GCS_AVAILABLE = False
|
||||
|
||||
from .base_storage import BaseStorageAdaptor, StorageObject
|
||||
|
||||
|
||||
class GCSStorageAdaptor(BaseStorageAdaptor):
    """
    Google Cloud Storage adaptor.

    Configuration:
        bucket: GCS bucket name (required)
        project: GCP project ID (optional, uses default)
        credentials_path: Path to service account JSON (optional)

    Environment Variables:
        GOOGLE_APPLICATION_CREDENTIALS: Path to service account JSON
        GOOGLE_CLOUD_PROJECT: GCP project ID

    Examples:
        # Using environment variables
        adaptor = GCSStorageAdaptor(bucket='my-bucket')

        # With explicit credentials
        adaptor = GCSStorageAdaptor(
            bucket='my-bucket',
            project='my-project',
            credentials_path='/path/to/credentials.json'
        )

        # Using default credentials
        adaptor = GCSStorageAdaptor(
            bucket='my-bucket',
            project='my-project'
        )
    """

    def __init__(self, **kwargs):
        """
        Initialize GCS storage adaptor.

        Args:
            bucket: GCS bucket name (required)
            **kwargs: Additional GCS configuration

        Raises:
            ImportError: If google-cloud-storage is not installed
            ValueError: If no bucket was provided
        """
        super().__init__(**kwargs)

        if not GCS_AVAILABLE:
            raise ImportError(
                "google-cloud-storage is required for GCS storage. "
                "Install with: pip install google-cloud-storage"
            )

        if 'bucket' not in kwargs:
            raise ValueError("bucket parameter is required for GCS storage")

        self.bucket_name = kwargs['bucket']
        self.project = kwargs.get('project', os.getenv('GOOGLE_CLOUD_PROJECT'))

        # Initialize GCS client
        client_kwargs = {}
        if self.project:
            client_kwargs['project'] = self.project

        # NOTE(review): mutates process-wide environment so the google
        # client picks up the credentials file — this affects every other
        # GCP client created in this process afterwards.
        if 'credentials_path' in kwargs:
            os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = kwargs['credentials_path']

        self.storage_client = storage.Client(**client_kwargs)
        self.bucket = self.storage_client.bucket(self.bucket_name)

    def upload_file(
        self, local_path: str, remote_path: str, metadata: Optional[Dict[str, str]] = None
    ) -> str:
        """Upload file to GCS; returns the object's gs:// URI."""
        local_file = Path(local_path)
        if not local_file.exists():
            raise FileNotFoundError(f"Local file not found: {local_path}")

        try:
            blob = self.bucket.blob(remote_path)

            # Custom metadata must be set on the blob before uploading
            if metadata:
                blob.metadata = metadata

            blob.upload_from_filename(str(local_file))
            return f"gs://{self.bucket_name}/{remote_path}"
        except Exception as e:
            raise Exception(f"GCS upload failed: {e}")

    def download_file(self, remote_path: str, local_path: str) -> None:
        """Download file from GCS, creating parent directories as needed."""
        local_file = Path(local_path)
        local_file.parent.mkdir(parents=True, exist_ok=True)

        try:
            blob = self.bucket.blob(remote_path)
            blob.download_to_filename(str(local_file))
        except NotFound:
            # Map the GCS-specific error onto the base-class contract
            raise FileNotFoundError(f"Remote file not found: {remote_path}")
        except Exception as e:
            raise Exception(f"GCS download failed: {e}")

    def delete_file(self, remote_path: str) -> None:
        """Delete file from GCS; raises FileNotFoundError if absent."""
        try:
            blob = self.bucket.blob(remote_path)
            blob.delete()
        except NotFound:
            raise FileNotFoundError(f"Remote file not found: {remote_path}")
        except Exception as e:
            raise Exception(f"GCS deletion failed: {e}")

    def list_files(
        self, prefix: str = "", max_results: int = 1000
    ) -> List[StorageObject]:
        """List objects in the GCS bucket matching the given prefix."""
        try:
            blobs = self.storage_client.list_blobs(
                self.bucket_name,
                prefix=prefix,
                max_results=max_results
            )

            files = []
            for blob in blobs:
                files.append(StorageObject(
                    key=blob.name,
                    size=blob.size,
                    # blob.updated may be None for freshly-created handles
                    last_modified=blob.updated.isoformat() if blob.updated else None,
                    etag=blob.etag,
                    metadata=blob.metadata
                ))

            return files
        except Exception as e:
            raise Exception(f"GCS listing failed: {e}")

    def file_exists(self, remote_path: str) -> bool:
        """Check whether the object exists in GCS (single metadata request)."""
        try:
            blob = self.bucket.blob(remote_path)
            return blob.exists()
        except Exception as e:
            raise Exception(f"GCS file existence check failed: {e}")

    def get_file_url(self, remote_path: str, expires_in: int = 3600) -> str:
        """Generate a V4 signed GET URL for a GCS object."""
        try:
            blob = self.bucket.blob(remote_path)

            # Signing does not verify existence, so check explicitly first
            if not blob.exists():
                raise FileNotFoundError(f"Remote file not found: {remote_path}")

            url = blob.generate_signed_url(
                version="v4",
                expiration=timedelta(seconds=expires_in),
                method="GET"
            )
            return url
        except FileNotFoundError:
            # Preserve the contract exception instead of wrapping it below
            raise
        except Exception as e:
            raise Exception(f"GCS signed URL generation failed: {e}")

    def copy_file(self, source_path: str, dest_path: str) -> None:
        """Copy file within the GCS bucket (server-side copy; overrides
        the base class's download-then-upload default)."""
        try:
            source_blob = self.bucket.blob(source_path)

            if not source_blob.exists():
                raise FileNotFoundError(f"Source file not found: {source_path}")

            self.bucket.copy_blob(
                source_blob,
                self.bucket,
                dest_path
            )
        except FileNotFoundError:
            raise
        except Exception as e:
            raise Exception(f"GCS copy failed: {e}")
|
||||
216
src/skill_seekers/cli/storage/s3_storage.py
Normal file
216
src/skill_seekers/cli/storage/s3_storage.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
AWS S3 storage adaptor implementation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
BOTO3_AVAILABLE = True
|
||||
except ImportError:
|
||||
BOTO3_AVAILABLE = False
|
||||
|
||||
from .base_storage import BaseStorageAdaptor, StorageObject
|
||||
|
||||
|
||||
class S3StorageAdaptor(BaseStorageAdaptor):
    """
    AWS S3 storage adaptor.

    Configuration:
        bucket: S3 bucket name (required)
        region: AWS region (optional, default: us-east-1)
        aws_access_key_id: AWS access key (optional, uses env/credentials)
        aws_secret_access_key: AWS secret key (optional, uses env/credentials)
        endpoint_url: Custom endpoint URL (optional, for S3-compatible services)

    Environment Variables:
        AWS_ACCESS_KEY_ID: AWS access key
        AWS_SECRET_ACCESS_KEY: AWS secret key
        AWS_DEFAULT_REGION: AWS region

    Examples:
        # Using environment variables
        adaptor = S3StorageAdaptor(bucket='my-bucket')

        # With explicit credentials
        adaptor = S3StorageAdaptor(
            bucket='my-bucket',
            region='us-west-2',
            aws_access_key_id='AKIAIOSFODNN7EXAMPLE',
            aws_secret_access_key='wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY'
        )

        # S3-compatible service (MinIO, DigitalOcean Spaces)
        adaptor = S3StorageAdaptor(
            bucket='my-bucket',
            endpoint_url='https://nyc3.digitaloceanspaces.com',
            aws_access_key_id='...',
            aws_secret_access_key='...'
        )
    """

    def __init__(self, **kwargs):
        """
        Initialize S3 storage adaptor.

        Args:
            bucket: S3 bucket name (required)
            **kwargs: Additional S3 configuration

        Raises:
            ImportError: If boto3 is not installed
            ValueError: If no bucket was provided
        """
        super().__init__(**kwargs)

        if not BOTO3_AVAILABLE:
            raise ImportError(
                "boto3 is required for S3 storage. "
                "Install with: pip install boto3"
            )

        if 'bucket' not in kwargs:
            raise ValueError("bucket parameter is required for S3 storage")

        self.bucket = kwargs['bucket']
        self.region = kwargs.get('region', os.getenv('AWS_DEFAULT_REGION', 'us-east-1'))

        # Initialize S3 client. Credentials omitted here fall back to the
        # standard boto3 chain (env vars, shared config, instance role).
        client_kwargs = {
            'region_name': self.region,
        }

        if 'endpoint_url' in kwargs:
            client_kwargs['endpoint_url'] = kwargs['endpoint_url']

        if 'aws_access_key_id' in kwargs:
            client_kwargs['aws_access_key_id'] = kwargs['aws_access_key_id']

        if 'aws_secret_access_key' in kwargs:
            client_kwargs['aws_secret_access_key'] = kwargs['aws_secret_access_key']

        self.s3_client = boto3.client('s3', **client_kwargs)
        self.s3_resource = boto3.resource('s3', **client_kwargs)

    def upload_file(
        self, local_path: str, remote_path: str, metadata: Optional[Dict[str, str]] = None
    ) -> str:
        """Upload file to S3; returns the object's s3:// URI."""
        local_file = Path(local_path)
        if not local_file.exists():
            raise FileNotFoundError(f"Local file not found: {local_path}")

        extra_args = {}
        if metadata:
            extra_args['Metadata'] = metadata

        try:
            self.s3_client.upload_file(
                str(local_file),
                self.bucket,
                remote_path,
                ExtraArgs=extra_args if extra_args else None
            )
            return f"s3://{self.bucket}/{remote_path}"
        except ClientError as e:
            raise Exception(f"S3 upload failed: {e}")

    def download_file(self, remote_path: str, local_path: str) -> None:
        """Download file from S3, creating parent directories as needed."""
        local_file = Path(local_path)
        local_file.parent.mkdir(parents=True, exist_ok=True)

        try:
            self.s3_client.download_file(
                self.bucket,
                remote_path,
                str(local_file)
            )
        except ClientError as e:
            # Map a missing key onto the base-class contract exception
            if e.response['Error']['Code'] == '404':
                raise FileNotFoundError(f"Remote file not found: {remote_path}")
            raise Exception(f"S3 download failed: {e}")

    def delete_file(self, remote_path: str) -> None:
        """Delete file from S3.

        NOTE(review): S3 delete_object succeeds for nonexistent keys, so
        this never raises FileNotFoundError even though the base-class
        delete_file contract documents it — confirm whether intended.
        """
        try:
            self.s3_client.delete_object(
                Bucket=self.bucket,
                Key=remote_path
            )
        except ClientError as e:
            raise Exception(f"S3 deletion failed: {e}")

    def list_files(
        self, prefix: str = "", max_results: int = 1000
    ) -> List[StorageObject]:
        """List objects in the S3 bucket matching the given prefix."""
        try:
            # Paginate so prefixes with >1000 keys are handled, capped
            # overall at max_results items
            paginator = self.s3_client.get_paginator('list_objects_v2')
            page_iterator = paginator.paginate(
                Bucket=self.bucket,
                Prefix=prefix,
                PaginationConfig={'MaxItems': max_results}
            )

            files = []
            for page in page_iterator:
                # Empty result pages carry no 'Contents' key
                if 'Contents' not in page:
                    continue

                for obj in page['Contents']:
                    files.append(StorageObject(
                        key=obj['Key'],
                        size=obj['Size'],
                        last_modified=obj['LastModified'].isoformat(),
                        # S3 returns ETags wrapped in double quotes
                        etag=obj.get('ETag', '').strip('"')
                    ))

            return files
        except ClientError as e:
            raise Exception(f"S3 listing failed: {e}")

    def file_exists(self, remote_path: str) -> bool:
        """Check whether the object exists in S3 (HEAD request)."""
        try:
            self.s3_client.head_object(
                Bucket=self.bucket,
                Key=remote_path
            )
            return True
        except ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            # Anything other than "not found" (e.g. access denied) is an error
            raise Exception(f"S3 head_object failed: {e}")

    def get_file_url(self, remote_path: str, expires_in: int = 3600) -> str:
        """Generate a presigned GET URL for an S3 object.

        NOTE(review): unlike the GCS adaptor, this does not verify the
        object exists before signing — presigning succeeds for missing
        keys; confirm whether a FileNotFoundError check is wanted.
        """
        try:
            url = self.s3_client.generate_presigned_url(
                'get_object',
                Params={
                    'Bucket': self.bucket,
                    'Key': remote_path
                },
                ExpiresIn=expires_in
            )
            return url
        except ClientError as e:
            raise Exception(f"S3 presigned URL generation failed: {e}")

    def copy_file(self, source_path: str, dest_path: str) -> None:
        """Copy file within the S3 bucket (server-side copy; overrides
        the base class's download-then-upload default)."""
        try:
            copy_source = {
                'Bucket': self.bucket,
                'Key': source_path
            }
            self.s3_client.copy_object(
                CopySource=copy_source,
                Bucket=self.bucket,
                Key=dest_path
            )
        except ClientError as e:
            if e.response['Error']['Code'] == '404':
                raise FileNotFoundError(f"Source file not found: {source_path}")
            raise Exception(f"S3 copy failed: {e}")
|
||||
224
src/skill_seekers/cli/sync_cli.py
Normal file
224
src/skill_seekers/cli/sync_cli.py
Normal file
@@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Documentation sync CLI.
|
||||
|
||||
Monitor documentation for changes and automatically update skills.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import signal
|
||||
from pathlib import Path
|
||||
|
||||
from ..sync import SyncMonitor
|
||||
|
||||
|
||||
def handle_signal(signum, frame):
    """Terminate the process cleanly when SIGINT/SIGTERM arrives."""
    message = "\n🛑 Stopping sync monitor..."
    print(message)
    sys.exit(0)
|
||||
|
||||
|
||||
def start_command(args):
    """Start continuous monitoring until interrupted.

    Args:
        args: Parsed CLI namespace with config, interval, auto_update.
    """
    # Hoisted out of the keep-alive loop: the original re-executed
    # `import time` on every iteration.
    import time

    monitor = SyncMonitor(
        config_path=args.config,
        check_interval=args.interval,
        auto_update=args.auto_update
    )

    # Register signal handlers so Ctrl+C / SIGTERM stop the process cleanly
    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    try:
        monitor.start()

        print(f"\n📊 Monitoring {args.config}")
        print(f" Check interval: {args.interval}s ({args.interval // 60}m)")
        print(f" Auto-update: {'✅ enabled' if args.auto_update else '❌ disabled'}")
        print("\nPress Ctrl+C to stop\n")

        # Keep the main thread alive; the monitor does its work elsewhere
        while True:
            time.sleep(1)

    except KeyboardInterrupt:
        print("\n🛑 Stopping...")
        monitor.stop()
|
||||
|
||||
|
||||
def check_command(args):
    """Run a single change check and print a summary report.

    Args:
        args: Parsed CLI namespace with config, diff, verbose.
    """
    monitor = SyncMonitor(
        config_path=args.config,
        check_interval=3600  # Not used for single check
    )

    print(f"🔍 Checking {args.config} for changes...")

    report = monitor.check_now(generate_diffs=args.diff)

    print(f"\n📊 Results:")
    print(f" Total pages: {report.total_pages}")
    print(f" Added: {len(report.added)}")
    print(f" Modified: {len(report.modified)}")
    print(f" Deleted: {len(report.deleted)}")
    print(f" Unchanged: {report.unchanged}")

    if report.has_changes:
        print(f"\n✨ Detected {report.change_count} changes!")

        # Per-page detail only with --verbose
        if args.verbose:
            if report.added:
                print("\n✅ Added pages:")
                for change in report.added:
                    print(f" • {change.url}")

            if report.modified:
                print("\n✏️ Modified pages:")
                for change in report.modified:
                    print(f" • {change.url}")
                    # Diff text is only present when --diff was requested
                    if change.diff and args.diff:
                        print(f" Diff preview (first 5 lines):")
                        for line in change.diff.split('\n')[:5]:
                            print(f" {line}")

            if report.deleted:
                print("\n❌ Deleted pages:")
                for change in report.deleted:
                    print(f" • {change.url}")
    else:
        print("\n✅ No changes detected")
|
||||
|
||||
|
||||
def stats_command(args):
    """Show monitoring statistics for the configured skill.

    Args:
        args: Parsed CLI namespace with config.
    """
    monitor = SyncMonitor(
        config_path=args.config,
        check_interval=3600  # Not used when only reading stats
    )

    stats = monitor.stats()

    print(f"\n📊 Statistics for {stats['skill_name']}:")
    print(f" Status: {stats['status']}")
    print(f" Last check: {stats['last_check'] or 'Never'}")
    print(f" Last change: {stats['last_change'] or 'Never'}")
    print(f" Total checks: {stats['total_checks']}")
    print(f" Total changes: {stats['total_changes']}")
    print(f" Tracked pages: {stats['tracked_pages']}")
    print(f" Running: {'✅ Yes' if stats['running'] else '❌ No'}")
|
||||
|
||||
|
||||
def reset_command(args):
    """Delete the persisted monitoring state file (with confirmation).

    Args:
        args: Parsed CLI namespace with skill_name, force.
    """
    # NOTE(review): the state file is resolved relative to the current
    # working directory — confirm this matches where SyncMonitor writes it.
    state_file = Path(f"{args.skill_name}_sync.json")

    if state_file.exists():
        # --force skips the interactive confirmation prompt
        if args.force or input(f"⚠️ Reset state for {args.skill_name}? [y/N]: ").lower() == 'y':
            state_file.unlink()
            print(f"✅ State reset for {args.skill_name}")
        else:
            print("❌ Reset cancelled")
    else:
        print(f"ℹ️ No state file found for {args.skill_name}")
|
||||
|
||||
|
||||
def main():
    """Main entry point: parse CLI arguments and dispatch to a subcommand."""
    parser = argparse.ArgumentParser(
        description='Monitor documentation for changes and update skills',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Start monitoring (checks every hour)
  skill-seekers-sync start --config configs/react.json

  # Start with custom interval (10 minutes)
  skill-seekers-sync start --config configs/react.json --interval 600

  # Start with auto-update
  skill-seekers-sync start --config configs/react.json --auto-update

  # Check once (no continuous monitoring)
  skill-seekers-sync check --config configs/react.json

  # Check with diffs
  skill-seekers-sync check --config configs/react.json --diff -v

  # Show statistics
  skill-seekers-sync stats --config configs/react.json

  # Reset state
  skill-seekers-sync reset --skill-name react
        """
    )

    subparsers = parser.add_subparsers(dest='command', help='Command to execute')

    # Start command: continuous monitoring loop
    start_parser = subparsers.add_parser('start', help='Start continuous monitoring')
    start_parser.add_argument('--config', required=True, help='Path to skill config file')
    start_parser.add_argument(
        '--interval', '-i',
        type=int,
        default=3600,
        help='Check interval in seconds (default: 3600 = 1 hour)'
    )
    start_parser.add_argument(
        '--auto-update',
        action='store_true',
        help='Automatically rebuild skill on changes'
    )

    # Check command: one-shot change detection
    check_parser = subparsers.add_parser('check', help='Check for changes once')
    check_parser.add_argument('--config', required=True, help='Path to skill config file')
    check_parser.add_argument(
        '--diff', '-d',
        action='store_true',
        help='Generate content diffs'
    )
    check_parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Show detailed output'
    )

    # Stats command: read-only state report
    stats_parser = subparsers.add_parser('stats', help='Show monitoring statistics')
    stats_parser.add_argument('--config', required=True, help='Path to skill config file')

    # Reset command: delete persisted state
    reset_parser = subparsers.add_parser('reset', help='Reset monitoring state')
    reset_parser.add_argument('--skill-name', required=True, help='Skill name')
    reset_parser.add_argument(
        '--force', '-f',
        action='store_true',
        help='Skip confirmation'
    )

    args = parser.parse_args()

    # No subcommand given: show usage and exit with failure status
    if not args.command:
        parser.print_help()
        sys.exit(1)

    try:
        if args.command == 'start':
            start_command(args)
        elif args.command == 'check':
            check_command(args)
        elif args.command == 'stats':
            stats_command(args)
        elif args.command == 'reset':
            reset_command(args)
    except Exception as e:
        # Top-level boundary: surface any failure as a nonzero exit
        print(f"\n❌ Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
|
||||
31
src/skill_seekers/embedding/__init__.py
Normal file
31
src/skill_seekers/embedding/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Embedding generation system for Skill Seekers.
|
||||
|
||||
Provides:
|
||||
- FastAPI server for embedding generation
|
||||
- Multiple embedding model support (OpenAI, sentence-transformers, Anthropic)
|
||||
- Batch processing for efficiency
|
||||
- Caching layer for embeddings
|
||||
- Vector database integration
|
||||
|
||||
Usage:
|
||||
# Start server
|
||||
python -m skill_seekers.embedding.server
|
||||
|
||||
# Generate embeddings
|
||||
curl -X POST http://localhost:8000/embed \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"texts": ["Hello world"], "model": "text-embedding-3-small"}'
|
||||
"""
|
||||
|
||||
from .models import EmbeddingRequest, EmbeddingResponse, BatchEmbeddingRequest
|
||||
from .generator import EmbeddingGenerator
|
||||
from .cache import EmbeddingCache
|
||||
|
||||
__all__ = [
|
||||
'EmbeddingRequest',
|
||||
'EmbeddingResponse',
|
||||
'BatchEmbeddingRequest',
|
||||
'EmbeddingGenerator',
|
||||
'EmbeddingCache',
|
||||
]
|
||||
335
src/skill_seekers/embedding/cache.py
Normal file
335
src/skill_seekers/embedding/cache.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Caching layer for embeddings.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class EmbeddingCache:
    """
    SQLite-based cache for embeddings.

    Stores embeddings with their text hashes to avoid regeneration.
    Supports TTL (time-to-live) for cache entries.

    Examples:
        cache = EmbeddingCache("/path/to/cache.db")

        # Store embedding
        cache.set("hash123", [0.1, 0.2, 0.3], model="text-embedding-3-small")

        # Retrieve embedding
        embedding = cache.get("hash123")

        # Check if cached
        if cache.has("hash123"):
            print("Embedding is cached")
    """

    def __init__(self, db_path: str = ":memory:", ttl_days: int = 30):
        """
        Initialize embedding cache.

        Args:
            db_path: Path to SQLite database (":memory:" for in-memory)
            ttl_days: Time-to-live for cache entries in days
        """
        self.db_path = db_path
        self.ttl_days = ttl_days

        # Create database directory if needed
        if db_path != ":memory:":
            Path(db_path).parent.mkdir(parents=True, exist_ok=True)

        # Initialize database.
        # NOTE(review): check_same_thread=False shares one connection
        # across threads with no lock here — sqlite3 connections are not
        # thread-safe by themselves; confirm callers serialize access.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self._init_db()

    def _init_db(self):
        """Initialize database schema (idempotent: CREATE IF NOT EXISTS)."""
        cursor = self.conn.cursor()

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                hash TEXT PRIMARY KEY,
                embedding TEXT NOT NULL,
                model TEXT NOT NULL,
                dimensions INTEGER NOT NULL,
                created_at TEXT NOT NULL,
                accessed_at TEXT NOT NULL,
                access_count INTEGER DEFAULT 1
            )
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_model ON embeddings(model)
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_created_at ON embeddings(created_at)
        """)

        self.conn.commit()

    def set(
        self,
        hash_key: str,
        embedding: List[float],
        model: str
    ) -> None:
        """
        Store embedding in cache.

        Args:
            hash_key: Hash of text+model
            embedding: Embedding vector
            model: Model name
        """
        cursor = self.conn.cursor()

        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # switching to timezone-aware timestamps would change the stored
        # ISO format, so existing rows must be migrated together.
        now = datetime.utcnow().isoformat()
        # Vectors are stored as JSON text in the TEXT column
        embedding_json = json.dumps(embedding)
        dimensions = len(embedding)

        # INSERT OR REPLACE resets access_count to 1 for re-stored keys
        cursor.execute("""
            INSERT OR REPLACE INTO embeddings
            (hash, embedding, model, dimensions, created_at, accessed_at, access_count)
            VALUES (?, ?, ?, ?, ?, ?, 1)
        """, (hash_key, embedding_json, model, dimensions, now, now))

        self.conn.commit()

    def get(self, hash_key: str) -> Optional[List[float]]:
        """
        Retrieve embedding from cache.

        Args:
            hash_key: Hash of text+model

        Returns:
            Embedding vector if cached and not expired, None otherwise
        """
        cursor = self.conn.cursor()

        # Get embedding
        cursor.execute("""
            SELECT embedding, created_at
            FROM embeddings
            WHERE hash = ?
        """, (hash_key,))

        row = cursor.fetchone()
        if not row:
            return None

        embedding_json, created_at = row

        # Check TTL: entries older than ttl_days are evicted lazily on read
        created = datetime.fromisoformat(created_at)
        if datetime.utcnow() - created > timedelta(days=self.ttl_days):
            # Expired, delete and return None
            self.delete(hash_key)
            return None

        # Update access stats (used by stats() top_accessed ranking)
        now = datetime.utcnow().isoformat()
        cursor.execute("""
            UPDATE embeddings
            SET accessed_at = ?, access_count = access_count + 1
            WHERE hash = ?
        """, (now, hash_key))
        self.conn.commit()

        return json.loads(embedding_json)

    def get_batch(self, hash_keys: List[str]) -> Tuple[List[Optional[List[float]]], List[bool]]:
        """
        Retrieve multiple embeddings from cache.

        Args:
            hash_keys: List of hashes

        Returns:
            Tuple of (embeddings list, cached flags)
            embeddings list contains None for cache misses
        """
        embeddings = []
        cached_flags = []

        # One query per key; each get() also bumps that key's access stats
        for hash_key in hash_keys:
            embedding = self.get(hash_key)
            embeddings.append(embedding)
            cached_flags.append(embedding is not None)

        return embeddings, cached_flags

    def has(self, hash_key: str) -> bool:
        """
        Check if embedding is cached and not expired.

        Args:
            hash_key: Hash of text+model

        Returns:
            True if cached and not expired, False otherwise
        """
        cursor = self.conn.cursor()

        cursor.execute("""
            SELECT created_at
            FROM embeddings
            WHERE hash = ?
        """, (hash_key,))

        row = cursor.fetchone()
        if not row:
            return False

        # Check TTL — expired entries are evicted lazily, same as get()
        created = datetime.fromisoformat(row[0])
        if datetime.utcnow() - created > timedelta(days=self.ttl_days):
            # Expired
            self.delete(hash_key)
            return False

        return True

    def delete(self, hash_key: str) -> None:
        """
        Delete embedding from cache.

        Args:
            hash_key: Hash of text+model
        """
        cursor = self.conn.cursor()

        cursor.execute("""
            DELETE FROM embeddings
            WHERE hash = ?
        """, (hash_key,))

        self.conn.commit()

    def clear(self, model: Optional[str] = None) -> int:
        """
        Clear cache entries.

        Args:
            model: If provided, only clear entries for this model

        Returns:
            Number of entries deleted
        """
        cursor = self.conn.cursor()

        if model:
            cursor.execute("""
                DELETE FROM embeddings
                WHERE model = ?
            """, (model,))
        else:
            cursor.execute("DELETE FROM embeddings")

        # rowcount reflects rows affected by the last execute
        deleted = cursor.rowcount
        self.conn.commit()

        return deleted

    def clear_expired(self) -> int:
        """
        Clear expired cache entries.

        Returns:
            Number of entries deleted
        """
        cursor = self.conn.cursor()

        # ISO-8601 strings compare chronologically, so a string comparison
        # against the cutoff works here
        cutoff = (datetime.utcnow() - timedelta(days=self.ttl_days)).isoformat()

        cursor.execute("""
            DELETE FROM embeddings
            WHERE created_at < ?
        """, (cutoff,))

        deleted = cursor.rowcount
        self.conn.commit()

        return deleted

    def size(self) -> int:
        """
        Get number of cached embeddings.

        Returns:
            Number of cache entries
        """
        cursor = self.conn.cursor()

        cursor.execute("SELECT COUNT(*) FROM embeddings")
        return cursor.fetchone()[0]

    def stats(self) -> dict:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache stats: total entries, per-model counts,
            ten most-accessed entries, expired-entry count, and ttl_days.
        """
        cursor = self.conn.cursor()

        # Total entries
        cursor.execute("SELECT COUNT(*) FROM embeddings")
        total = cursor.fetchone()[0]

        # Entries by model
        cursor.execute("""
            SELECT model, COUNT(*)
            FROM embeddings
            GROUP BY model
        """)
        by_model = {row[0]: row[1] for row in cursor.fetchall()}

        # Most accessed
        cursor.execute("""
            SELECT hash, model, access_count
            FROM embeddings
            ORDER BY access_count DESC
            LIMIT 10
        """)
        top_accessed = [
            {"hash": row[0], "model": row[1], "access_count": row[2]}
            for row in cursor.fetchall()
        ]

        # Expired entries (still present; evicted lazily or by clear_expired)
        cutoff = (datetime.utcnow() - timedelta(days=self.ttl_days)).isoformat()
        cursor.execute("""
            SELECT COUNT(*)
            FROM embeddings
            WHERE created_at < ?
        """, (cutoff,))
        expired = cursor.fetchone()[0]

        return {
            "total": total,
            "by_model": by_model,
            "top_accessed": top_accessed,
            "expired": expired,
            "ttl_days": self.ttl_days
        }

    def close(self):
        """Close database connection."""
        self.conn.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: always close the connection."""
        self.close()
|
||||
443
src/skill_seekers/embedding/generator.py
Normal file
443
src/skill_seekers/embedding/generator.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Embedding generation with multiple model support.
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
# OpenAI support
|
||||
try:
|
||||
from openai import OpenAI
|
||||
OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
OPENAI_AVAILABLE = False
|
||||
|
||||
# Sentence transformers support
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
SENTENCE_TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SENTENCE_TRANSFORMERS_AVAILABLE = False
|
||||
|
||||
# Voyage AI support (recommended by Anthropic for embeddings)
|
||||
try:
|
||||
import voyageai
|
||||
VOYAGE_AVAILABLE = True
|
||||
except ImportError:
|
||||
VOYAGE_AVAILABLE = False
|
||||
|
||||
|
||||
class EmbeddingGenerator:
    """
    Generate embeddings using multiple model providers.

    Supported providers:
    - OpenAI (text-embedding-3-small, text-embedding-3-large, text-embedding-ada-002)
    - Sentence Transformers (all-MiniLM-L6-v2, all-mpnet-base-v2, etc.)
    - Anthropic/Voyage AI (voyage-2, voyage-large-2)

    Examples:
        # OpenAI embeddings
        generator = EmbeddingGenerator()
        embedding = generator.generate("Hello world", model="text-embedding-3-small")

        # Sentence transformers (local, no API)
        embedding = generator.generate("Hello world", model="all-MiniLM-L6-v2")

        # Batch generation
        embeddings = generator.generate_batch(
            ["text1", "text2", "text3"],
            model="text-embedding-3-small"
        )
    """

    # Model registry: routes each model name to its provider and records
    # output dimensions, input limits, and pricing metadata.
    MODELS = {
        # OpenAI models
        "text-embedding-3-small": {
            "provider": "openai",
            "dimensions": 1536,
            "max_tokens": 8191,
            "cost_per_million": 0.02,
        },
        "text-embedding-3-large": {
            "provider": "openai",
            "dimensions": 3072,
            "max_tokens": 8191,
            "cost_per_million": 0.13,
        },
        "text-embedding-ada-002": {
            "provider": "openai",
            "dimensions": 1536,
            "max_tokens": 8191,
            "cost_per_million": 0.10,
        },
        # Voyage AI models (recommended by Anthropic)
        "voyage-3": {
            "provider": "voyage",
            "dimensions": 1024,
            "max_tokens": 32000,
            "cost_per_million": 0.06,
        },
        "voyage-3-lite": {
            "provider": "voyage",
            "dimensions": 512,
            "max_tokens": 32000,
            "cost_per_million": 0.06,
        },
        "voyage-large-2": {
            "provider": "voyage",
            "dimensions": 1536,
            "max_tokens": 16000,
            "cost_per_million": 0.12,
        },
        "voyage-code-2": {
            "provider": "voyage",
            "dimensions": 1536,
            "max_tokens": 16000,
            "cost_per_million": 0.12,
        },
        "voyage-2": {
            "provider": "voyage",
            "dimensions": 1024,
            "max_tokens": 4000,
            "cost_per_million": 0.10,
        },
        # Sentence transformer models (local, free)
        "all-MiniLM-L6-v2": {
            "provider": "sentence-transformers",
            "dimensions": 384,
            "max_tokens": 256,
            "cost_per_million": 0.0,
        },
        "all-mpnet-base-v2": {
            "provider": "sentence-transformers",
            "dimensions": 768,
            "max_tokens": 384,
            "cost_per_million": 0.0,
        },
        "paraphrase-MiniLM-L6-v2": {
            "provider": "sentence-transformers",
            "dimensions": 384,
            "max_tokens": 128,
            "cost_per_million": 0.0,
        },
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        voyage_api_key: Optional[str] = None,
        cache_dir: Optional[str] = None
    ):
        """
        Initialize embedding generator.

        Args:
            api_key: API key for OpenAI (falls back to OPENAI_API_KEY env var)
            voyage_api_key: API key for Voyage AI (falls back to VOYAGE_API_KEY)
            cache_dir: Directory for caching models (sentence-transformers)
        """
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.voyage_api_key = voyage_api_key or os.getenv("VOYAGE_API_KEY")
        self.cache_dir = cache_dir

        # Clients are None when the SDK is missing or no key was supplied;
        # the per-provider methods raise a clear error in that case.
        if OPENAI_AVAILABLE and self.api_key:
            self.openai_client = OpenAI(api_key=self.api_key)
        else:
            self.openai_client = None

        if VOYAGE_AVAILABLE and self.voyage_api_key:
            self.voyage_client = voyageai.Client(api_key=self.voyage_api_key)
        else:
            self.voyage_client = None

        # Lazy-loaded SentenceTransformer instances, keyed by model name.
        self._st_models = {}

    def get_model_info(self, model: str) -> dict:
        """Get configuration info for a model.

        Raises:
            ValueError: If the model name is not in MODELS.
        """
        if model not in self.MODELS:
            raise ValueError(
                f"Unknown model: {model}. "
                f"Available models: {', '.join(self.MODELS.keys())}"
            )
        return self.MODELS[model]

    def list_models(self) -> List[dict]:
        """List all available models with their metadata."""
        return [
            {
                "name": name,
                "provider": info["provider"],
                "dimensions": info["dimensions"],
                "max_tokens": info["max_tokens"],
                "cost_per_million": info.get("cost_per_million", 0.0),
            }
            for name, info in self.MODELS.items()
        ]

    def generate(
        self,
        text: str,
        model: str = "text-embedding-3-small",
        normalize: bool = True
    ) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed
            model: Model name
            normalize: Whether to normalize to unit length

        Returns:
            Embedding vector

        Raises:
            ValueError: If model is not supported
            Exception: If embedding generation fails
        """
        provider = self.get_model_info(model)["provider"]

        if provider == "openai":
            return self._generate_openai(text, model, normalize)
        elif provider == "voyage":
            return self._generate_voyage(text, model, normalize)
        elif provider == "sentence-transformers":
            return self._generate_sentence_transformer(text, model, normalize)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def generate_batch(
        self,
        texts: List[str],
        model: str = "text-embedding-3-small",
        normalize: bool = True,
        batch_size: int = 32
    ) -> Tuple[List[List[float]], int]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: List of texts to embed
            model: Model name
            normalize: Whether to normalize to unit length
            batch_size: Batch size for processing

        Returns:
            Tuple of (embeddings list, dimensions); dimensions is 0 when
            texts is empty.

        Raises:
            ValueError: If model is not supported
            Exception: If embedding generation fails
        """
        provider = self.get_model_info(model)["provider"]

        if provider == "openai":
            return self._generate_openai_batch(texts, model, normalize, batch_size)
        elif provider == "voyage":
            return self._generate_voyage_batch(texts, model, normalize, batch_size)
        elif provider == "sentence-transformers":
            return self._generate_sentence_transformer_batch(texts, model, normalize, batch_size)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def _generate_openai(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate a single embedding via the OpenAI API."""
        if not OPENAI_AVAILABLE:
            raise ImportError(
                "OpenAI is required for OpenAI embeddings. "
                "Install with: pip install openai"
            )

        if not self.openai_client:
            raise ValueError("OpenAI API key not provided")

        try:
            response = self.openai_client.embeddings.create(
                input=text,
                model=model
            )
            embedding = response.data[0].embedding

            if normalize:
                embedding = self._normalize(embedding)

            return embedding
        except Exception as e:
            # Chain the cause so the original API error stays in the traceback.
            raise Exception(f"OpenAI embedding generation failed: {e}") from e

    def _generate_openai_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings via the OpenAI API in batches."""
        if not OPENAI_AVAILABLE:
            raise ImportError(
                "OpenAI is required for OpenAI embeddings. "
                "Install with: pip install openai"
            )

        if not self.openai_client:
            raise ValueError("OpenAI API key not provided")

        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            try:
                response = self.openai_client.embeddings.create(
                    input=batch,
                    model=model
                )

                batch_embeddings = [item.embedding for item in response.data]

                if normalize:
                    batch_embeddings = [self._normalize(emb) for emb in batch_embeddings]

                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                raise Exception(f"OpenAI batch embedding generation failed: {e}") from e

        dimensions = len(all_embeddings[0]) if all_embeddings else 0
        return all_embeddings, dimensions

    def _generate_voyage(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate a single embedding via the Voyage AI API."""
        if not VOYAGE_AVAILABLE:
            raise ImportError(
                "voyageai is required for Voyage AI embeddings. "
                "Install with: pip install voyageai"
            )

        if not self.voyage_client:
            raise ValueError("Voyage API key not provided")

        try:
            result = self.voyage_client.embed(
                texts=[text],
                model=model
            )
            embedding = result.embeddings[0]

            if normalize:
                embedding = self._normalize(embedding)

            return embedding
        except Exception as e:
            raise Exception(f"Voyage AI embedding generation failed: {e}") from e

    def _generate_voyage_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings via the Voyage AI API in batches."""
        if not VOYAGE_AVAILABLE:
            raise ImportError(
                "voyageai is required for Voyage AI embeddings. "
                "Install with: pip install voyageai"
            )

        if not self.voyage_client:
            raise ValueError("Voyage API key not provided")

        all_embeddings = []

        # Process in batches (Voyage AI supports up to 128 texts per request)
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            try:
                result = self.voyage_client.embed(
                    texts=batch,
                    model=model
                )

                batch_embeddings = result.embeddings

                if normalize:
                    batch_embeddings = [self._normalize(emb) for emb in batch_embeddings]

                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                raise Exception(f"Voyage AI batch embedding generation failed: {e}") from e

        dimensions = len(all_embeddings[0]) if all_embeddings else 0
        return all_embeddings, dimensions

    def _get_st_model(self, model: str):
        """Return a SentenceTransformer for *model*, loading and caching it once."""
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            raise ImportError(
                "sentence-transformers is required for local embeddings. "
                "Install with: pip install sentence-transformers"
            )
        if model not in self._st_models:
            self._st_models[model] = SentenceTransformer(model, cache_folder=self.cache_dir)
        return self._st_models[model]

    def _generate_sentence_transformer(
        self, text: str, model: str, normalize: bool
    ) -> List[float]:
        """Generate a single embedding locally via sentence-transformers."""
        st_model = self._get_st_model(model)
        embedding = st_model.encode(text, normalize_embeddings=normalize)
        return embedding.tolist()

    def _generate_sentence_transformer_batch(
        self, texts: List[str], model: str, normalize: bool, batch_size: int
    ) -> Tuple[List[List[float]], int]:
        """Generate embeddings locally via sentence-transformers in batches."""
        st_model = self._get_st_model(model)

        embeddings = st_model.encode(
            texts,
            batch_size=batch_size,
            normalize_embeddings=normalize,
            show_progress_bar=False
        )

        dimensions = len(embeddings[0]) if len(embeddings) > 0 else 0
        return embeddings.tolist(), dimensions

    @staticmethod
    def _normalize(embedding: List[float]) -> List[float]:
        """Normalize an embedding to unit length (zero vectors pass through)."""
        vec = np.array(embedding)
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        return vec.tolist()

    @staticmethod
    def compute_hash(text: str, model: str) -> str:
        """Compute the cache key for a (text, model) pair: sha256 of "model:text"."""
        content = f"{model}:{text}"
        return hashlib.sha256(content.encode()).hexdigest()
|
||||
157
src/skill_seekers/embedding/models.py
Normal file
157
src/skill_seekers/embedding/models.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Pydantic models for embedding API.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
    """Request model for single embedding generation."""

    # Raw input text; provider-specific max-token limits are enforced upstream.
    text: str = Field(..., description="Text to generate embedding for")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    normalize: bool = Field(
        default=True,
        description="Normalize embeddings to unit length"
    )

    class Config:
        # Example payload shown in the OpenAPI docs (/docs).
        json_schema_extra = {
            "example": {
                "text": "This is a test document about Python programming.",
                "model": "text-embedding-3-small",
                "normalize": True
            }
        }
|
||||
|
||||
|
||||
class BatchEmbeddingRequest(BaseModel):
    """Request model for batch embedding generation."""

    texts: List[str] = Field(..., description="List of texts to embed")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    normalize: bool = Field(
        default=True,
        description="Normalize embeddings to unit length"
    )
    # Texts are sent to the provider in groups of this size.
    batch_size: Optional[int] = Field(
        default=32,
        description="Batch size for processing (default: 32)"
    )

    class Config:
        # Example payload shown in the OpenAPI docs (/docs).
        json_schema_extra = {
            "example": {
                "texts": [
                    "First document about Python",
                    "Second document about JavaScript",
                    "Third document about Rust"
                ],
                "model": "text-embedding-3-small",
                "normalize": True,
                "batch_size": 32
            }
        }
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
    """Response model for embedding generation."""

    embedding: List[float] = Field(..., description="Generated embedding vector")
    model: str = Field(..., description="Model used for generation")
    # Length of the embedding vector; varies per model (see generator MODELS).
    dimensions: int = Field(..., description="Embedding dimensions")
    cached: bool = Field(
        default=False,
        description="Whether embedding was retrieved from cache"
    )
|
||||
|
||||
|
||||
class BatchEmbeddingResponse(BaseModel):
    """Response model for batch embedding generation."""

    # Order matches the request's `texts` list.
    embeddings: List[List[float]] = Field(..., description="List of embedding vectors")
    model: str = Field(..., description="Model used for generation")
    dimensions: int = Field(..., description="Embedding dimensions")
    count: int = Field(..., description="Number of embeddings generated")
    cached_count: int = Field(
        default=0,
        description="Number of embeddings retrieved from cache"
    )
|
||||
|
||||
|
||||
class SkillEmbeddingRequest(BaseModel):
    """Request model for skill content embedding."""

    # Directory containing SKILL.md; validated server-side.
    skill_path: str = Field(..., description="Path to skill directory")
    model: str = Field(
        default="text-embedding-3-small",
        description="Embedding model to use"
    )
    chunk_size: int = Field(
        default=512,
        description="Chunk size for splitting documents (tokens)"
    )
    overlap: int = Field(
        default=50,
        description="Overlap between chunks (tokens)"
    )

    class Config:
        # Example payload shown in the OpenAPI docs (/docs).
        json_schema_extra = {
            "example": {
                "skill_path": "/path/to/skill/react",
                "model": "text-embedding-3-small",
                "chunk_size": 512,
                "overlap": 50
            }
        }
|
||||
|
||||
|
||||
class SkillEmbeddingResponse(BaseModel):
    """Response model for skill content embedding."""

    skill_name: str = Field(..., description="Name of the skill")
    total_chunks: int = Field(..., description="Total number of chunks embedded")
    model: str = Field(..., description="Model used for generation")
    dimensions: int = Field(..., description="Embedding dimensions")
    # Free-form details about the processed skill (path, chunk count, size).
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Skill metadata"
    )
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Health check response."""

    status: str = Field(..., description="Service status")
    version: str = Field(..., description="API version")
    models: List[str] = Field(..., description="Available embedding models")
    cache_enabled: bool = Field(..., description="Whether cache is enabled")
    # None when caching is disabled.
    cache_size: Optional[int] = Field(None, description="Number of cached embeddings")
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
    """Information about an embedding model."""

    name: str = Field(..., description="Model name")
    provider: str = Field(..., description="Model provider (openai, anthropic, sentence-transformers)")
    dimensions: int = Field(..., description="Embedding dimensions")
    max_tokens: int = Field(..., description="Maximum input tokens")
    # None/0.0 for free local models (sentence-transformers).
    cost_per_million: Optional[float] = Field(
        None,
        description="Cost per million tokens (if applicable)"
    )
|
||||
|
||||
|
||||
class ModelsResponse(BaseModel):
    """Response model for listing available models."""

    models: List[ModelInfo] = Field(..., description="List of available models")
    count: int = Field(..., description="Number of available models")
|
||||
362
src/skill_seekers/embedding/server.py
Normal file
362
src/skill_seekers/embedding/server.py
Normal file
@@ -0,0 +1,362 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FastAPI server for embedding generation.
|
||||
|
||||
Provides endpoints for:
|
||||
- Single and batch embedding generation
|
||||
- Skill content embedding
|
||||
- Model listing and information
|
||||
- Cache management
|
||||
- Health checks
|
||||
|
||||
Usage:
|
||||
# Start server
|
||||
python -m skill_seekers.embedding.server
|
||||
|
||||
# Or with uvicorn
|
||||
uvicorn skill_seekers.embedding.server:app --host 0.0.0.0 --port 8000
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
import uvicorn
|
||||
FASTAPI_AVAILABLE = True
|
||||
except ImportError:
|
||||
FASTAPI_AVAILABLE = False
|
||||
|
||||
from .models import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
BatchEmbeddingRequest,
|
||||
BatchEmbeddingResponse,
|
||||
SkillEmbeddingRequest,
|
||||
SkillEmbeddingResponse,
|
||||
HealthResponse,
|
||||
ModelInfo,
|
||||
ModelsResponse,
|
||||
)
|
||||
from .generator import EmbeddingGenerator
|
||||
from .cache import EmbeddingCache
|
||||
|
||||
|
||||
# Initialize FastAPI app
|
||||
if FASTAPI_AVAILABLE:
|
||||
    # Application instance exposed to uvicorn as
    # "skill_seekers.embedding.server:app".
    app = FastAPI(
        title="Skill Seekers Embedding API",
        description="Generate embeddings for text and skill content",
        version="1.0.0",
        docs_url="/docs",
        redoc_url="/redoc"
    )

    # Add CORS middleware
    # NOTE(review): allow_origins=["*"] together with allow_credentials=True is
    # wide open — acceptable for local tooling, tighten before public exposure.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Initialize generator and cache
    # Cache location and on/off switch are environment-driven.
    cache_dir = os.getenv("EMBEDDING_CACHE_DIR", os.path.expanduser("~/.cache/skill-seekers/embeddings"))
    cache_db = os.path.join(cache_dir, "embeddings.db")
    cache_enabled = os.getenv("EMBEDDING_CACHE_ENABLED", "true").lower() == "true"

    # API keys are read from the environment; providers without a key stay
    # disabled inside the generator.
    generator = EmbeddingGenerator(
        api_key=os.getenv("OPENAI_API_KEY"),
        voyage_api_key=os.getenv("VOYAGE_API_KEY")
    )
    # `cache` is None when disabled; endpoints guard on its truthiness.
    cache = EmbeddingCache(cache_db) if cache_enabled else None
|
||||
|
||||
@app.get("/", response_model=dict)
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"service": "Skill Seekers Embedding API",
|
||||
"version": "1.0.0",
|
||||
"docs": "/docs",
|
||||
"health": "/health"
|
||||
}
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
models = [m["name"] for m in generator.list_models()]
|
||||
cache_size = cache.size() if cache else None
|
||||
|
||||
return HealthResponse(
|
||||
status="ok",
|
||||
version="1.0.0",
|
||||
models=models,
|
||||
cache_enabled=cache_enabled,
|
||||
cache_size=cache_size
|
||||
)
|
||||
|
||||
@app.get("/models", response_model=ModelsResponse)
|
||||
async def list_models():
|
||||
"""List available embedding models."""
|
||||
models_list = generator.list_models()
|
||||
|
||||
model_infos = [
|
||||
ModelInfo(
|
||||
name=m["name"],
|
||||
provider=m["provider"],
|
||||
dimensions=m["dimensions"],
|
||||
max_tokens=m["max_tokens"],
|
||||
cost_per_million=m.get("cost_per_million")
|
||||
)
|
||||
for m in models_list
|
||||
]
|
||||
|
||||
return ModelsResponse(
|
||||
models=model_infos,
|
||||
count=len(model_infos)
|
||||
)
|
||||
|
||||
@app.post("/embed", response_model=EmbeddingResponse)
|
||||
async def embed_text(request: EmbeddingRequest):
|
||||
"""
|
||||
Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
request: Embedding request
|
||||
|
||||
Returns:
|
||||
Embedding response
|
||||
|
||||
Raises:
|
||||
HTTPException: If embedding generation fails
|
||||
"""
|
||||
try:
|
||||
# Check cache
|
||||
cached = False
|
||||
hash_key = generator.compute_hash(request.text, request.model)
|
||||
|
||||
if cache and cache.has(hash_key):
|
||||
embedding = cache.get(hash_key)
|
||||
cached = True
|
||||
else:
|
||||
# Generate embedding
|
||||
embedding = generator.generate(
|
||||
request.text,
|
||||
model=request.model,
|
||||
normalize=request.normalize
|
||||
)
|
||||
|
||||
# Store in cache
|
||||
if cache:
|
||||
cache.set(hash_key, embedding, request.model)
|
||||
|
||||
return EmbeddingResponse(
|
||||
embedding=embedding,
|
||||
model=request.model,
|
||||
dimensions=len(embedding),
|
||||
cached=cached
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/embed/batch", response_model=BatchEmbeddingResponse)
|
||||
async def embed_batch(request: BatchEmbeddingRequest):
|
||||
"""
|
||||
Generate embeddings for multiple texts.
|
||||
|
||||
Args:
|
||||
request: Batch embedding request
|
||||
|
||||
Returns:
|
||||
Batch embedding response
|
||||
|
||||
Raises:
|
||||
HTTPException: If embedding generation fails
|
||||
"""
|
||||
try:
|
||||
# Check cache for each text
|
||||
cached_count = 0
|
||||
embeddings = []
|
||||
texts_to_generate = []
|
||||
text_indices = []
|
||||
|
||||
for idx, text in enumerate(request.texts):
|
||||
hash_key = generator.compute_hash(text, request.model)
|
||||
|
||||
if cache and cache.has(hash_key):
|
||||
cached_embedding = cache.get(hash_key)
|
||||
embeddings.append(cached_embedding)
|
||||
cached_count += 1
|
||||
else:
|
||||
embeddings.append(None) # Placeholder
|
||||
texts_to_generate.append(text)
|
||||
text_indices.append(idx)
|
||||
|
||||
# Generate embeddings for uncached texts
|
||||
if texts_to_generate:
|
||||
generated_embeddings, dimensions = generator.generate_batch(
|
||||
texts_to_generate,
|
||||
model=request.model,
|
||||
normalize=request.normalize,
|
||||
batch_size=request.batch_size
|
||||
)
|
||||
|
||||
# Fill in placeholders and cache
|
||||
for idx, text, embedding in zip(text_indices, texts_to_generate, generated_embeddings):
|
||||
embeddings[idx] = embedding
|
||||
|
||||
if cache:
|
||||
hash_key = generator.compute_hash(text, request.model)
|
||||
cache.set(hash_key, embedding, request.model)
|
||||
|
||||
dimensions = len(embeddings[0]) if embeddings else 0
|
||||
|
||||
return BatchEmbeddingResponse(
|
||||
embeddings=embeddings,
|
||||
model=request.model,
|
||||
dimensions=dimensions,
|
||||
count=len(embeddings),
|
||||
cached_count=cached_count
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/embed/skill", response_model=SkillEmbeddingResponse)
|
||||
async def embed_skill(request: SkillEmbeddingRequest):
|
||||
"""
|
||||
Generate embeddings for skill content.
|
||||
|
||||
Args:
|
||||
request: Skill embedding request
|
||||
|
||||
Returns:
|
||||
Skill embedding response
|
||||
|
||||
Raises:
|
||||
HTTPException: If skill embedding fails
|
||||
"""
|
||||
try:
|
||||
skill_path = Path(request.skill_path)
|
||||
|
||||
if not skill_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Skill path not found: {request.skill_path}")
|
||||
|
||||
# Read SKILL.md
|
||||
skill_md = skill_path / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
raise HTTPException(status_code=404, detail=f"SKILL.md not found in {request.skill_path}")
|
||||
|
||||
skill_content = skill_md.read_text()
|
||||
|
||||
# Simple chunking (split by double newline)
|
||||
chunks = [
|
||||
chunk.strip()
|
||||
for chunk in skill_content.split("\n\n")
|
||||
if chunk.strip() and len(chunk.strip()) > 50
|
||||
]
|
||||
|
||||
# Generate embeddings for chunks
|
||||
embeddings, dimensions = generator.generate_batch(
|
||||
chunks,
|
||||
model=request.model,
|
||||
normalize=True,
|
||||
batch_size=32
|
||||
)
|
||||
|
||||
# TODO: Store embeddings in vector database
|
||||
# This would integrate with the vector database adaptors
|
||||
|
||||
return SkillEmbeddingResponse(
|
||||
skill_name=skill_path.name,
|
||||
total_chunks=len(chunks),
|
||||
model=request.model,
|
||||
dimensions=dimensions,
|
||||
metadata={
|
||||
"skill_path": str(skill_path),
|
||||
"chunks": len(chunks),
|
||||
"content_length": len(skill_content)
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/cache/stats", response_model=dict)
|
||||
async def cache_stats():
|
||||
"""Get cache statistics."""
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache is disabled")
|
||||
|
||||
return cache.stats()
|
||||
|
||||
@app.post("/cache/clear", response_model=dict)
|
||||
async def clear_cache(
|
||||
model: Optional[str] = Query(None, description="Model to clear (all if not specified)")
|
||||
):
|
||||
"""Clear cache entries."""
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache is disabled")
|
||||
|
||||
deleted = cache.clear(model=model)
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"deleted": deleted,
|
||||
"model": model or "all"
|
||||
}
|
||||
|
||||
@app.post("/cache/clear-expired", response_model=dict)
|
||||
async def clear_expired():
|
||||
"""Clear expired cache entries."""
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache is disabled")
|
||||
|
||||
deleted = cache.clear_expired()
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"deleted": deleted
|
||||
}
|
||||
|
||||
else:
|
||||
print("Error: FastAPI not available. Install with: pip install fastapi uvicorn")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
    """Entry point: configure from the environment and start uvicorn."""
    if not FASTAPI_AVAILABLE:
        print("Error: FastAPI not available. Install with: pip install fastapi uvicorn")
        sys.exit(1)

    # Server binding and reload behaviour are environment-driven.
    host = os.getenv("EMBEDDING_HOST", "0.0.0.0")
    port = int(os.getenv("EMBEDDING_PORT", "8000"))
    reload = os.getenv("EMBEDDING_RELOAD", "false").lower() == "true"

    print(f"🚀 Starting Embedding API server on {host}:{port}")
    print(f"📚 API documentation: http://{host}:{port}/docs")
    print(f"🔍 Cache enabled: {cache_enabled}")
    if cache_enabled:
        print(f"💾 Cache database: {cache_db}")

    # Pass the app as an import string so --reload style restarts work.
    uvicorn.run(
        "skill_seekers.embedding.server:app",
        host=host,
        port=port,
        reload=reload,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,19 +3,20 @@
|
||||
Skill Seeker MCP Server (FastMCP Implementation)
|
||||
|
||||
Modern, decorator-based MCP server using FastMCP for simplified tool registration.
|
||||
Provides 21 tools for generating Claude AI skills from documentation.
|
||||
Provides 25 tools for generating Claude AI skills from documentation.
|
||||
|
||||
This is a streamlined alternative to server.py (2200 lines → 708 lines, 68% reduction).
|
||||
All tool implementations are delegated to modular tool files in tools/ directory.
|
||||
|
||||
**Architecture:**
|
||||
- FastMCP server with decorator-based tool registration
|
||||
- 21 tools organized into 5 categories:
|
||||
- 25 tools organized into 6 categories:
|
||||
* Config tools (3): generate_config, list_configs, validate_config
|
||||
* Scraping tools (8): estimate_pages, scrape_docs, scrape_github, scrape_pdf, scrape_codebase, detect_patterns, extract_test_examples, build_how_to_guides, extract_config_patterns
|
||||
* Packaging tools (4): package_skill, upload_skill, enhance_skill, install_skill
|
||||
* Splitting tools (2): split_config, generate_router
|
||||
* Source tools (4): fetch_config, submit_config, add_config_source, list_config_sources, remove_config_source
|
||||
* Vector Database tools (4): export_to_weaviate, export_to_chroma, export_to_faiss, export_to_qdrant
|
||||
|
||||
**Usage:**
|
||||
# Stdio transport (default, backward compatible)
|
||||
@@ -75,6 +76,11 @@ try:
|
||||
enhance_skill_impl,
|
||||
# Scraping tools
|
||||
estimate_pages_impl,
|
||||
# Vector database tools
|
||||
export_to_chroma_impl,
|
||||
export_to_faiss_impl,
|
||||
export_to_qdrant_impl,
|
||||
export_to_weaviate_impl,
|
||||
extract_config_patterns_impl,
|
||||
extract_test_examples_impl,
|
||||
# Source tools
|
||||
@@ -109,6 +115,10 @@ except ImportError:
|
||||
detect_patterns_impl,
|
||||
enhance_skill_impl,
|
||||
estimate_pages_impl,
|
||||
export_to_chroma_impl,
|
||||
export_to_faiss_impl,
|
||||
export_to_qdrant_impl,
|
||||
export_to_weaviate_impl,
|
||||
extract_config_patterns_impl,
|
||||
extract_test_examples_impl,
|
||||
fetch_config_impl,
|
||||
@@ -1055,6 +1065,119 @@ async def remove_config_source(name: str) -> str:
|
||||
return str(result)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VECTOR DATABASE TOOLS (4 tools)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@safe_tool_decorator(
    description="Export skill to Weaviate vector database format. Weaviate supports hybrid search (vector + BM25 keyword) with 450K+ users. Ideal for production RAG applications."
)
async def export_to_weaviate(
    skill_dir: str,
    output_dir: str | None = None,
) -> str:
    """
    Export skill to Weaviate vector database format.

    Args:
        skill_dir: Path to skill directory (e.g., output/react/)
        output_dir: Output directory (default: same as skill_dir parent)

    Returns:
        Export results with package path and usage instructions.
    """
    # Build the impl-layer argument dict; output_dir only when provided.
    request = {"skill_dir": skill_dir}
    if output_dir:
        request["output_dir"] = output_dir

    payload = await export_to_weaviate_impl(request)
    # The impl returns a list of TextContent; unwrap the first entry's text.
    if isinstance(payload, list) and payload:
        first = payload[0]
        return first.text if hasattr(first, "text") else str(first)
    return str(payload)
|
||||
|
||||
|
||||
@safe_tool_decorator(
    description="Export skill to Chroma vector database format. Chroma is a popular open-source embedding database designed for local-first development with 800K+ developers."
)
async def export_to_chroma(
    skill_dir: str,
    output_dir: str | None = None,
) -> str:
    """
    Export skill to Chroma vector database format.

    Args:
        skill_dir: Path to skill directory (e.g., output/react/)
        output_dir: Output directory (default: same as skill_dir parent)

    Returns:
        Export results with package path and usage instructions.
    """
    request = {"skill_dir": skill_dir}
    if output_dir:
        request["output_dir"] = output_dir

    payload = await export_to_chroma_impl(request)
    # Unwrap the first TextContent entry when the impl returns a list.
    if isinstance(payload, list) and payload:
        first = payload[0]
        return first.text if hasattr(first, "text") else str(first)
    return str(payload)
|
||||
|
||||
|
||||
@safe_tool_decorator(
    description="Export skill to FAISS vector index format. FAISS (Facebook AI Similarity Search) supports billion-scale vector search with GPU acceleration."
)
async def export_to_faiss(
    skill_dir: str,
    output_dir: str | None = None,
) -> str:
    """
    Export skill to FAISS vector index format.

    Args:
        skill_dir: Path to skill directory (e.g., output/react/)
        output_dir: Output directory (default: same as skill_dir parent)

    Returns:
        Export results with package path and usage instructions.
    """
    request = {"skill_dir": skill_dir}
    if output_dir:
        request["output_dir"] = output_dir

    payload = await export_to_faiss_impl(request)
    # Unwrap the first TextContent entry when the impl returns a list.
    if isinstance(payload, list) and payload:
        first = payload[0]
        return first.text if hasattr(first, "text") else str(first)
    return str(payload)
|
||||
|
||||
|
||||
@safe_tool_decorator(
    description="Export skill to Qdrant vector database format. Qdrant is a modern vector database with native payload filtering and high-performance search, serving 100K+ users."
)
async def export_to_qdrant(
    skill_dir: str,
    output_dir: str | None = None,
) -> str:
    """
    Export skill to Qdrant vector database format.

    Args:
        skill_dir: Path to skill directory (e.g., output/react/)
        output_dir: Output directory (default: same as skill_dir parent)

    Returns:
        Export results with package path and usage instructions.
    """
    request = {"skill_dir": skill_dir}
    if output_dir:
        request["output_dir"] = output_dir

    payload = await export_to_qdrant_impl(request)
    # Unwrap the first TextContent entry when the impl returns a list.
    if isinstance(payload, list) and payload:
        first = payload[0]
        return first.text if hasattr(first, "text") else str(first)
    return str(payload)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN ENTRY POINT
|
||||
# ============================================================================
|
||||
|
||||
@@ -9,6 +9,7 @@ Tools are organized by functionality:
|
||||
- packaging_tools: Skill packaging and upload
|
||||
- splitting_tools: Config splitting and router generation
|
||||
- source_tools: Config source management (fetch, submit, add/remove sources)
|
||||
- vector_db_tools: Vector database export (Weaviate, Chroma, FAISS, Qdrant)
|
||||
"""
|
||||
|
||||
# Import centralized version
|
||||
@@ -83,6 +84,18 @@ from .splitting_tools import (
|
||||
from .splitting_tools import (
|
||||
split_config as split_config_impl,
|
||||
)
|
||||
from .vector_db_tools import (
|
||||
export_to_chroma_impl,
|
||||
)
|
||||
from .vector_db_tools import (
|
||||
export_to_faiss_impl,
|
||||
)
|
||||
from .vector_db_tools import (
|
||||
export_to_qdrant_impl,
|
||||
)
|
||||
from .vector_db_tools import (
|
||||
export_to_weaviate_impl,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
@@ -114,4 +127,9 @@ __all__ = [
|
||||
"add_config_source_impl",
|
||||
"list_config_sources_impl",
|
||||
"remove_config_source_impl",
|
||||
# Vector database tools
|
||||
"export_to_weaviate_impl",
|
||||
"export_to_chroma_impl",
|
||||
"export_to_faiss_impl",
|
||||
"export_to_qdrant_impl",
|
||||
]
|
||||
|
||||
489
src/skill_seekers/mcp/tools/vector_db_tools.py
Normal file
489
src/skill_seekers/mcp/tools/vector_db_tools.py
Normal file
@@ -0,0 +1,489 @@
|
||||
"""
|
||||
Vector Database Tools for MCP Server.
|
||||
|
||||
Provides MCP tools for exporting skills to 4 vector databases:
|
||||
- Weaviate (hybrid search, 450K+ users)
|
||||
- Chroma (local-first, 800K+ developers)
|
||||
- FAISS (billion-scale, GPU-accelerated)
|
||||
- Qdrant (native filtering, 100K+ users)
|
||||
|
||||
Each tool provides a direct interface to its respective vector database adaptor.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
try:
    from mcp.types import TextContent
except ImportError:
    # Graceful degradation for testing: the tool implementations below only
    # ever construct TextContent(type=..., text=...) and read .text, so this
    # minimal stand-in keeps them importable without the MCP package.
    class TextContent:
        """Fallback TextContent for when MCP is not installed"""

        def __init__(self, type: str, text: str):
            # Field names mirror mcp.types.TextContent.
            self.type = type
            self.text = text
|
||||
|
||||
|
||||
# Path to CLI adaptors.
# NOTE(review): this mutates sys.path at import time so the sibling CLI
# `adaptors` package resolves regardless of install layout — confirm this is
# intentional rather than a packaging workaround.
CLI_DIR = Path(__file__).parent.parent.parent / "cli"
sys.path.insert(0, str(CLI_DIR))

try:
    from adaptors import get_adaptor
except ImportError:
    # Sentinel checked by every export_*_impl before use.
    get_adaptor = None  # Will handle gracefully below
|
||||
|
||||
|
||||
async def export_to_weaviate_impl(args: dict) -> List[TextContent]:
    """
    Export skill to Weaviate vector database format.

    Weaviate is a popular cloud-native vector database with hybrid search
    (combining vector similarity + BM25 keyword search). Ideal for
    production RAG applications with 450K+ users.

    Args:
        args: Dictionary with:
            - skill_dir (str): Path to skill directory (e.g., output/react/)
            - output_dir (str, optional): Output directory (default: same as skill_dir)

    Returns:
        List of TextContent with export results

    Example:
        {
            "skill_dir": "output/react",
            "output_dir": "output"
        }

    Output Format:
        JSON file with Weaviate schema:
        - class_name: Weaviate class name
        - schema: Property definitions
        - objects: Document objects with vectors and metadata
        - config: Distance metric configuration
    """
    # Adaptors module failed to import at module load; report instead of raising.
    if get_adaptor is None:
        return [
            TextContent(
                type="text",
                text="❌ Error: Could not import adaptors module. Please ensure skill-seekers is properly installed.",
            )
        ]

    skill_dir = Path(args["skill_dir"])
    # Default the output next to the skill directory itself.
    output_dir = Path(args.get("output_dir", skill_dir.parent))

    if not skill_dir.exists():
        return [
            TextContent(
                type="text",
                text=f"❌ Error: Skill directory not found: {skill_dir}\n\nPlease scrape documentation first using scrape_docs.",
            )
        ]

    try:
        # Get Weaviate adaptor
        adaptor = get_adaptor("weaviate")

        # Package skill
        package_path = adaptor.package(skill_dir, output_dir)

        # Success message (the fenced snippets below are user-facing
        # instructions, not executed code).
        result_text = f"""✅ Weaviate Export Complete!

📦 Package: {package_path.name}
📁 Location: {package_path.parent}
📊 Size: {package_path.stat().st_size:,} bytes

🔧 Next Steps:
1. Upload to Weaviate:
   ```python
   import weaviate
   import json

   client = weaviate.Client("http://localhost:8080")
   data = json.load(open("{package_path}"))

   # Create schema
   client.schema.create_class(data["schema"])

   # Batch upload objects
   with client.batch as batch:
       for obj in data["objects"]:
           batch.add_data_object(obj["properties"], data["class_name"])
   ```

2. Query with hybrid search:
   ```python
   result = client.query.get(data["class_name"], ["content", "source"]) \\
       .with_hybrid("React hooks usage") \\
       .with_limit(5) \\
       .do()
   ```

📚 Resources:
- Weaviate Docs: https://weaviate.io/developers/weaviate
- Hybrid Search: https://weaviate.io/developers/weaviate/search/hybrid
"""

        return [TextContent(type="text", text=result_text)]

    except Exception as e:
        # Broad catch is deliberate: MCP tools report errors as text payloads.
        return [
            TextContent(
                type="text",
                text=f"❌ Error exporting to Weaviate: {str(e)}\n\nPlease check that the skill directory contains valid documentation.",
            )
        ]
|
||||
|
||||
|
||||
async def export_to_chroma_impl(args: dict) -> List[TextContent]:
    """
    Export skill to Chroma vector database format.

    Chroma is a popular open-source embedding database designed for
    local-first development. Perfect for RAG prototyping with 800K+ developers.

    Args:
        args: Dictionary with:
            - skill_dir (str): Path to skill directory (e.g., output/react/)
            - output_dir (str, optional): Output directory (default: same as skill_dir)

    Returns:
        List of TextContent with export results

    Example:
        {
            "skill_dir": "output/react",
            "output_dir": "output"
        }

    Output Format:
        JSON file with Chroma collection data:
        - collection_name: Collection identifier
        - documents: List of document texts
        - metadatas: List of metadata dicts
        - ids: List of unique IDs
    """
    # Adaptors module failed to import at module load; report instead of raising.
    if get_adaptor is None:
        return [
            TextContent(
                type="text",
                text="❌ Error: Could not import adaptors module.",
            )
        ]

    skill_dir = Path(args["skill_dir"])
    # Default the output next to the skill directory itself.
    output_dir = Path(args.get("output_dir", skill_dir.parent))

    if not skill_dir.exists():
        return [
            TextContent(
                type="text",
                text=f"❌ Error: Skill directory not found: {skill_dir}",
            )
        ]

    try:
        adaptor = get_adaptor("chroma")
        package_path = adaptor.package(skill_dir, output_dir)

        # User-facing instructions; the {{...}} below is an escaped literal
        # brace inside this f-string, not an interpolation.
        result_text = f"""✅ Chroma Export Complete!

📦 Package: {package_path.name}
📁 Location: {package_path.parent}
📊 Size: {package_path.stat().st_size:,} bytes

🔧 Next Steps:
1. Load into Chroma:
   ```python
   import chromadb
   import json

   client = chromadb.Client()
   data = json.load(open("{package_path}"))

   # Create collection
   collection = client.create_collection(
       name=data["collection_name"],
       metadata={{"source": "skill-seekers"}}
   )

   # Add documents
   collection.add(
       documents=data["documents"],
       metadatas=data["metadatas"],
       ids=data["ids"]
   )
   ```

2. Query the collection:
   ```python
   results = collection.query(
       query_texts=["How to use React hooks?"],
       n_results=5
   )
   ```

📚 Resources:
- Chroma Docs: https://docs.trychroma.com/
- Getting Started: https://docs.trychroma.com/getting-started
"""

        return [TextContent(type="text", text=result_text)]

    except Exception as e:
        # Broad catch is deliberate: MCP tools report errors as text payloads.
        return [
            TextContent(
                type="text",
                text=f"❌ Error exporting to Chroma: {str(e)}",
            )
        ]
|
||||
|
||||
|
||||
async def export_to_faiss_impl(args: dict) -> List[TextContent]:
    """
    Export skill to FAISS vector index format.

    FAISS (Facebook AI Similarity Search) is a library for efficient similarity
    search at billion-scale. Supports GPU acceleration for ultra-fast search.

    Args:
        args: Dictionary with:
            - skill_dir (str): Path to skill directory (e.g., output/react/)
            - output_dir (str, optional): Output directory (default: same as skill_dir)
            - index_type (str, optional): FAISS index type (default: 'Flat')
                Options: 'Flat', 'IVF', 'HNSW'

    Returns:
        List of TextContent with export results

    Example:
        {
            "skill_dir": "output/react",
            "output_dir": "output",
            "index_type": "HNSW"
        }

    Output Format:
        JSON file with FAISS data:
        - embeddings: List of embedding vectors
        - metadata: List of document metadata
        - index_config: FAISS index configuration
    """
    # NOTE(review): the documented index_type argument is not read by this
    # body — confirm whether the adaptor handles it or the doc is stale.
    # Adaptors module failed to import at module load; report instead of raising.
    if get_adaptor is None:
        return [
            TextContent(
                type="text",
                text="❌ Error: Could not import adaptors module.",
            )
        ]

    skill_dir = Path(args["skill_dir"])
    # Default the output next to the skill directory itself.
    output_dir = Path(args.get("output_dir", skill_dir.parent))

    if not skill_dir.exists():
        return [
            TextContent(
                type="text",
                text=f"❌ Error: Skill directory not found: {skill_dir}",
            )
        ]

    try:
        adaptor = get_adaptor("faiss")
        package_path = adaptor.package(skill_dir, output_dir)

        # User-facing instructions, not executed code.
        result_text = f"""✅ FAISS Export Complete!

📦 Package: {package_path.name}
📁 Location: {package_path.parent}
📊 Size: {package_path.stat().st_size:,} bytes

🔧 Next Steps:
1. Build FAISS index:
   ```python
   import faiss
   import json
   import numpy as np

   data = json.load(open("{package_path}"))
   embeddings = np.array(data["embeddings"], dtype="float32")

   # Create index (choose based on scale)
   dimension = embeddings.shape[1]

   # Option 1: Flat (exact search, small datasets)
   index = faiss.IndexFlatL2(dimension)

   # Option 2: IVF (fast approximation, medium datasets)
   # quantizer = faiss.IndexFlatL2(dimension)
   # index = faiss.IndexIVFFlat(quantizer, dimension, 100)
   # index.train(embeddings)

   # Option 3: HNSW (best quality approximation, large datasets)
   # index = faiss.IndexHNSWFlat(dimension, 32)

   # Add vectors
   index.add(embeddings)
   ```

2. Search:
   ```python
   # Search for similar docs
   query = np.array([your_query_embedding], dtype="float32")
   distances, indices = index.search(query, k=5)

   # Get metadata for results
   for i in indices[0]:
       print(data["metadata"][i])
   ```

3. Save index:
   ```python
   faiss.write_index(index, "react_docs.index")
   ```

📚 Resources:
- FAISS Wiki: https://github.com/facebookresearch/faiss/wiki
- GPU Support: https://github.com/facebookresearch/faiss/wiki/Faiss-on-the-GPU
"""

        return [TextContent(type="text", text=result_text)]

    except Exception as e:
        # Broad catch is deliberate: MCP tools report errors as text payloads.
        return [
            TextContent(
                type="text",
                text=f"❌ Error exporting to FAISS: {str(e)}",
            )
        ]
|
||||
|
||||
|
||||
async def export_to_qdrant_impl(args: dict) -> List[TextContent]:
    """
    Export skill to Qdrant vector database format.

    Qdrant is a modern vector database with native payload filtering and
    high-performance search. Ideal for production RAG with 100K+ users.

    Args:
        args: Dictionary with:
            - skill_dir (str): Path to skill directory (e.g., output/react/)
            - output_dir (str, optional): Output directory (default: same as skill_dir)

    Returns:
        List of TextContent with export results

    Example:
        {
            "skill_dir": "output/react",
            "output_dir": "output"
        }

    Output Format:
        JSON file with Qdrant collection data:
        - collection_name: Collection identifier
        - points: List of points with id, vector, payload
        - config: Vector configuration
    """
    # Adaptors module failed to import at module load; report instead of raising.
    if get_adaptor is None:
        return [
            TextContent(
                type="text",
                text="❌ Error: Could not import adaptors module.",
            )
        ]

    skill_dir = Path(args["skill_dir"])
    # Default the output next to the skill directory itself.
    output_dir = Path(args.get("output_dir", skill_dir.parent))

    if not skill_dir.exists():
        return [
            TextContent(
                type="text",
                text=f"❌ Error: Skill directory not found: {skill_dir}",
            )
        ]

    try:
        adaptor = get_adaptor("qdrant")
        package_path = adaptor.package(skill_dir, output_dir)

        # User-facing instructions, not executed code.
        result_text = f"""✅ Qdrant Export Complete!

📦 Package: {package_path.name}
📁 Location: {package_path.parent}
📊 Size: {package_path.stat().st_size:,} bytes

🔧 Next Steps:
1. Upload to Qdrant:
   ```python
   from qdrant_client import QdrantClient
   from qdrant_client.models import Distance, VectorParams
   import json

   client = QdrantClient("localhost", port=6333)
   data = json.load(open("{package_path}"))

   # Create collection
   client.create_collection(
       collection_name=data["collection_name"],
       vectors_config=VectorParams(
           size=data["config"]["vector_size"],
           distance=Distance.COSINE
       )
   )

   # Upload points
   client.upsert(
       collection_name=data["collection_name"],
       points=data["points"]
   )
   ```

2. Search with filters:
   ```python
   from qdrant_client.models import Filter, FieldCondition, MatchValue

   results = client.search(
       collection_name=data["collection_name"],
       query_vector=your_query_vector,
       query_filter=Filter(
           must=[
               FieldCondition(
                   key="category",
                   match=MatchValue(value="getting_started")
               )
           ]
       ),
       limit=5
   )
   ```

📚 Resources:
- Qdrant Docs: https://qdrant.tech/documentation/
- Filtering: https://qdrant.tech/documentation/concepts/filtering/
"""

        return [TextContent(type="text", text=result_text)]

    except Exception as e:
        # Broad catch is deliberate: MCP tools report errors as text payloads.
        return [
            TextContent(
                type="text",
                text=f"❌ Error exporting to Qdrant: {str(e)}",
            )
        ]
|
||||
|
||||
|
||||
# Export all implementations — the public API of this tools module;
# consumed by the package __init__ re-exports.
__all__ = [
    "export_to_weaviate_impl",
    "export_to_chroma_impl",
    "export_to_faiss_impl",
    "export_to_qdrant_impl",
]
|
||||
40
src/skill_seekers/sync/__init__.py
Normal file
40
src/skill_seekers/sync/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Real-time documentation sync system.
|
||||
|
||||
Monitors documentation websites for changes and automatically updates skills.
|
||||
|
||||
Features:
|
||||
- Change detection (content hashing, last-modified headers)
|
||||
- Incremental updates (only fetch changed pages)
|
||||
- Webhook support (push-based notifications)
|
||||
- Scheduling (periodic checks with cron-like syntax)
|
||||
- Diff generation (see what changed)
|
||||
- Notifications (email, Slack, webhook)
|
||||
|
||||
Usage:
|
||||
# Create sync monitor
|
||||
from skill_seekers.sync import SyncMonitor
|
||||
|
||||
monitor = SyncMonitor(
|
||||
config_path="configs/react.json",
|
||||
check_interval=3600 # 1 hour
|
||||
)
|
||||
|
||||
# Start monitoring
|
||||
monitor.start()
|
||||
|
||||
# Or run once
|
||||
changes = monitor.check_for_updates()
|
||||
"""
|
||||
|
||||
from .monitor import SyncMonitor
|
||||
from .detector import ChangeDetector
|
||||
from .models import SyncConfig, ChangeReport, PageChange
|
||||
|
||||
__all__ = [
|
||||
'SyncMonitor',
|
||||
'ChangeDetector',
|
||||
'SyncConfig',
|
||||
'ChangeReport',
|
||||
'PageChange',
|
||||
]
|
||||
321
src/skill_seekers/sync/detector.py
Normal file
321
src/skill_seekers/sync/detector.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Change detection for documentation pages.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import difflib
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
from .models import PageChange, ChangeType, ChangeReport
|
||||
|
||||
|
||||
class ChangeDetector:
|
||||
"""
|
||||
Detects changes in documentation pages.
|
||||
|
||||
Uses multiple strategies:
|
||||
1. Content hashing (SHA-256)
|
||||
2. Last-Modified headers
|
||||
3. ETag headers
|
||||
4. Content diffing
|
||||
|
||||
Examples:
|
||||
detector = ChangeDetector()
|
||||
|
||||
# Check single page
|
||||
change = detector.check_page(
|
||||
url="https://react.dev/learn",
|
||||
old_hash="abc123"
|
||||
)
|
||||
|
||||
# Generate diff
|
||||
diff = detector.generate_diff(old_content, new_content)
|
||||
|
||||
# Check multiple pages
|
||||
changes = detector.check_pages(urls, previous_state)
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: int = 30):
|
||||
"""
|
||||
Initialize change detector.
|
||||
|
||||
Args:
|
||||
timeout: Request timeout in seconds
|
||||
"""
|
||||
self.timeout = timeout
|
||||
|
||||
def compute_hash(self, content: str) -> str:
|
||||
"""
|
||||
Compute SHA-256 hash of content.
|
||||
|
||||
Args:
|
||||
content: Page content
|
||||
|
||||
Returns:
|
||||
Hexadecimal hash string
|
||||
"""
|
||||
return hashlib.sha256(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def fetch_page(self, url: str) -> Tuple[str, Dict[str, str]]:
|
||||
"""
|
||||
Fetch page content and metadata.
|
||||
|
||||
Args:
|
||||
url: Page URL
|
||||
|
||||
Returns:
|
||||
Tuple of (content, metadata)
|
||||
metadata includes: last-modified, etag, content-type
|
||||
|
||||
Raises:
|
||||
requests.RequestException: If fetch fails
|
||||
"""
|
||||
response = requests.get(
|
||||
url,
|
||||
timeout=self.timeout,
|
||||
headers={'User-Agent': 'SkillSeekers-Sync/1.0'}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
metadata = {
|
||||
'last-modified': response.headers.get('Last-Modified'),
|
||||
'etag': response.headers.get('ETag'),
|
||||
'content-type': response.headers.get('Content-Type'),
|
||||
'content-length': response.headers.get('Content-Length'),
|
||||
}
|
||||
|
||||
return response.text, metadata
|
||||
|
||||
def check_page(
|
||||
self,
|
||||
url: str,
|
||||
old_hash: Optional[str] = None,
|
||||
generate_diff: bool = False,
|
||||
old_content: Optional[str] = None
|
||||
) -> PageChange:
|
||||
"""
|
||||
Check if page has changed.
|
||||
|
||||
Args:
|
||||
url: Page URL
|
||||
old_hash: Previous content hash
|
||||
generate_diff: Whether to generate diff
|
||||
old_content: Previous content (for diff generation)
|
||||
|
||||
Returns:
|
||||
PageChange object
|
||||
|
||||
Raises:
|
||||
requests.RequestException: If fetch fails
|
||||
"""
|
||||
try:
|
||||
content, metadata = self.fetch_page(url)
|
||||
new_hash = self.compute_hash(content)
|
||||
|
||||
# Determine change type
|
||||
if old_hash is None:
|
||||
change_type = ChangeType.ADDED
|
||||
elif old_hash == new_hash:
|
||||
change_type = ChangeType.UNCHANGED
|
||||
else:
|
||||
change_type = ChangeType.MODIFIED
|
||||
|
||||
# Generate diff if requested
|
||||
diff = None
|
||||
if generate_diff and old_content and change_type == ChangeType.MODIFIED:
|
||||
diff = self.generate_diff(old_content, content)
|
||||
|
||||
return PageChange(
|
||||
url=url,
|
||||
change_type=change_type,
|
||||
old_hash=old_hash,
|
||||
new_hash=new_hash,
|
||||
diff=diff,
|
||||
detected_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
except requests.RequestException as e:
|
||||
# Page might be deleted or temporarily unavailable
|
||||
return PageChange(
|
||||
url=url,
|
||||
change_type=ChangeType.DELETED,
|
||||
old_hash=old_hash,
|
||||
new_hash=None,
|
||||
detected_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
def check_pages(
|
||||
self,
|
||||
urls: List[str],
|
||||
previous_hashes: Dict[str, str],
|
||||
generate_diffs: bool = False
|
||||
) -> ChangeReport:
|
||||
"""
|
||||
Check multiple pages for changes.
|
||||
|
||||
Args:
|
||||
urls: List of URLs to check
|
||||
previous_hashes: URL -> hash mapping from previous state
|
||||
generate_diffs: Whether to generate diffs
|
||||
|
||||
Returns:
|
||||
ChangeReport with all detected changes
|
||||
"""
|
||||
added = []
|
||||
modified = []
|
||||
deleted = []
|
||||
unchanged_count = 0
|
||||
|
||||
# Check each URL
|
||||
checked_urls = set()
|
||||
for url in urls:
|
||||
checked_urls.add(url)
|
||||
old_hash = previous_hashes.get(url)
|
||||
|
||||
change = self.check_page(url, old_hash, generate_diff=generate_diffs)
|
||||
|
||||
if change.change_type == ChangeType.ADDED:
|
||||
added.append(change)
|
||||
elif change.change_type == ChangeType.MODIFIED:
|
||||
modified.append(change)
|
||||
elif change.change_type == ChangeType.UNCHANGED:
|
||||
unchanged_count += 1
|
||||
|
||||
# Check for deleted pages (in previous state but not in current)
|
||||
for url, old_hash in previous_hashes.items():
|
||||
if url not in checked_urls:
|
||||
deleted.append(PageChange(
|
||||
url=url,
|
||||
change_type=ChangeType.DELETED,
|
||||
old_hash=old_hash,
|
||||
new_hash=None,
|
||||
detected_at=datetime.utcnow()
|
||||
))
|
||||
|
||||
return ChangeReport(
|
||||
skill_name="unknown", # To be set by caller
|
||||
total_pages=len(urls),
|
||||
added=added,
|
||||
modified=modified,
|
||||
deleted=deleted,
|
||||
unchanged=unchanged_count,
|
||||
checked_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
def generate_diff(self, old_content: str, new_content: str) -> str:
|
||||
"""
|
||||
Generate unified diff between old and new content.
|
||||
|
||||
Args:
|
||||
old_content: Original content
|
||||
new_content: New content
|
||||
|
||||
Returns:
|
||||
Unified diff string
|
||||
"""
|
||||
old_lines = old_content.splitlines(keepends=True)
|
||||
new_lines = new_content.splitlines(keepends=True)
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
old_lines,
|
||||
new_lines,
|
||||
fromfile='old',
|
||||
tofile='new',
|
||||
lineterm=''
|
||||
)
|
||||
|
||||
return ''.join(diff)
|
||||
|
||||
def generate_summary_diff(self, old_content: str, new_content: str) -> str:
|
||||
"""
|
||||
Generate human-readable diff summary.
|
||||
|
||||
Args:
|
||||
old_content: Original content
|
||||
new_content: New content
|
||||
|
||||
Returns:
|
||||
Summary string with added/removed line counts
|
||||
"""
|
||||
old_lines = old_content.splitlines()
|
||||
new_lines = new_content.splitlines()
|
||||
|
||||
diff = difflib.unified_diff(old_lines, new_lines)
|
||||
diff_lines = list(diff)
|
||||
|
||||
added = sum(1 for line in diff_lines if line.startswith('+') and not line.startswith('+++'))
|
||||
removed = sum(1 for line in diff_lines if line.startswith('-') and not line.startswith('---'))
|
||||
|
||||
return f"+{added} -{removed} lines"
|
||||
|
||||
def check_header_changes(
|
||||
self,
|
||||
url: str,
|
||||
old_modified: Optional[str] = None,
|
||||
old_etag: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Quick check using HTTP headers (no content download).
|
||||
|
||||
Args:
|
||||
url: Page URL
|
||||
old_modified: Previous Last-Modified header
|
||||
old_etag: Previous ETag header
|
||||
|
||||
Returns:
|
||||
True if headers indicate change, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Use HEAD request for efficiency
|
||||
response = requests.head(
|
||||
url,
|
||||
timeout=self.timeout,
|
||||
headers={'User-Agent': 'SkillSeekers-Sync/1.0'}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
new_modified = response.headers.get('Last-Modified')
|
||||
new_etag = response.headers.get('ETag')
|
||||
|
||||
# Check if headers indicate change
|
||||
if old_modified and new_modified and old_modified != new_modified:
|
||||
return True
|
||||
|
||||
if old_etag and new_etag and old_etag != new_etag:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except requests.RequestException:
|
||||
# If HEAD request fails, assume change (will be verified with GET)
|
||||
return True
|
||||
|
||||
def batch_check_headers(
    self,
    urls: List[str],
    previous_metadata: Dict[str, Dict[str, str]]
) -> List[str]:
    """
    Batch check URLs using headers only.

    Args:
        urls: URLs to check
        previous_metadata: URL -> metadata mapping

    Returns:
        List of URLs that likely changed
    """
    suspects: List[str] = []

    for page_url in urls:
        # Unknown URLs get an empty metadata dict, i.e. no previous
        # validators to compare against.
        meta = previous_metadata.get(page_url, {})
        changed = self.check_header_changes(
            page_url,
            meta.get('last-modified'),
            meta.get('etag'),
        )
        if changed:
            suspects.append(page_url)

    return suspects
|
||||
164
src/skill_seekers/sync/models.py
Normal file
164
src/skill_seekers/sync/models.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Pydantic models for sync system.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChangeType(str, Enum):
    """Type of change detected."""
    # Subclassing str makes members JSON-serializable as their values, so
    # these strings double as the wire/persisted representation — renaming
    # a value would break saved state and webhook consumers.
    ADDED = "added"
    MODIFIED = "modified"
    DELETED = "deleted"
    UNCHANGED = "unchanged"
|
||||
|
||||
|
||||
class PageChange(BaseModel):
    """Represents a change to a single page."""

    # Canonical identifier of the page the change applies to.
    url: str = Field(..., description="Page URL")
    change_type: ChangeType = Field(..., description="Type of change")
    # Hashes may be None where not applicable — presumably ADDED pages have
    # no old_hash and DELETED pages no new_hash; confirm against the detector.
    old_hash: Optional[str] = Field(None, description="Previous content hash")
    new_hash: Optional[str] = Field(None, description="New content hash")
    diff: Optional[str] = Field(None, description="Content diff (if available)")
    # NOTE(review): datetime.utcnow yields a *naive* datetime; confirm no
    # caller compares it against timezone-aware values before changing.
    detected_at: datetime = Field(
        default_factory=datetime.utcnow,
        description="When change was detected"
    )

    class Config:
        # Example surfaced in the generated JSON schema / API docs.
        json_schema_extra = {
            "example": {
                "url": "https://react.dev/learn/thinking-in-react",
                "change_type": "modified",
                "old_hash": "abc123",
                "new_hash": "def456",
                "diff": "@@ -10,3 +10,4 @@\n+New content here",
                "detected_at": "2024-01-15T10:30:00Z"
            }
        }
|
||||
|
||||
|
||||
class ChangeReport(BaseModel):
    """Report of all changes detected."""

    skill_name: str = Field(..., description="Skill name")
    total_pages: int = Field(..., description="Total pages checked")
    added: List[PageChange] = Field(default_factory=list, description="Added pages")
    modified: List[PageChange] = Field(default_factory=list, description="Modified pages")
    deleted: List[PageChange] = Field(default_factory=list, description="Deleted pages")
    # Unchanged pages are only counted, not listed individually.
    unchanged: int = Field(0, description="Number of unchanged pages")
    # NOTE(review): naive UTC timestamp (datetime.utcnow) — see PageChange.
    checked_at: datetime = Field(
        default_factory=datetime.utcnow,
        description="When check was performed"
    )

    @property
    def has_changes(self) -> bool:
        """Check if any changes were detected."""
        # True iff at least one page was added, modified, or deleted;
        # `unchanged` alone never counts as a change.
        return bool(self.added or self.modified or self.deleted)

    @property
    def change_count(self) -> int:
        """Total number of changes."""
        return len(self.added) + len(self.modified) + len(self.deleted)
|
||||
|
||||
|
||||
class SyncConfig(BaseModel):
    """Configuration for sync monitoring."""

    # Path to the skill's scraper config file (JSON), e.g. "configs/react.json".
    skill_config: str = Field(..., description="Path to skill config file")
    check_interval: int = Field(
        default=3600,
        description="Check interval in seconds (default: 1 hour)"
    )
    enabled: bool = Field(default=True, description="Whether sync is enabled")
    auto_update: bool = Field(
        default=False,
        description="Automatically rebuild skill on changes"
    )
    notify_on_change: bool = Field(
        default=True,
        description="Send notifications on changes"
    )
    # Channel names here select which of the concrete targets below are used.
    notification_channels: List[str] = Field(
        default_factory=list,
        description="Notification channels (email, slack, webhook)"
    )
    webhook_url: Optional[str] = Field(
        None,
        description="Webhook URL for change notifications"
    )
    email_recipients: List[str] = Field(
        default_factory=list,
        description="Email recipients for notifications"
    )
    slack_webhook: Optional[str] = Field(
        None,
        description="Slack webhook URL"
    )

    class Config:
        # Example surfaced in the generated JSON schema / API docs.
        json_schema_extra = {
            "example": {
                "skill_config": "configs/react.json",
                "check_interval": 3600,
                "enabled": True,
                "auto_update": False,
                "notify_on_change": True,
                "notification_channels": ["slack", "webhook"],
                "webhook_url": "https://example.com/webhook",
                "slack_webhook": "https://hooks.slack.com/services/..."
            }
        }
|
||||
|
||||
|
||||
class SyncState(BaseModel):
    """Current state of sync monitoring (persisted to disk between runs)."""

    skill_name: str = Field(..., description="Skill name")
    last_check: Optional[datetime] = Field(None, description="Last check time")
    last_change: Optional[datetime] = Field(None, description="Last change detected")
    total_checks: int = Field(default=0, description="Total checks performed")
    total_changes: int = Field(default=0, description="Total changes detected")
    # Baseline for change detection: URL -> content hash from the last check.
    page_hashes: Dict[str, str] = Field(
        default_factory=dict,
        description="URL -> content hash mapping"
    )
    # Free-form status string; observed values elsewhere in this package
    # include "idle", "checking", and "error".
    status: str = Field(default="idle", description="Current status")
    error: Optional[str] = Field(None, description="Last error message")
|
||||
|
||||
|
||||
class WebhookPayload(BaseModel):
    """Payload for webhook notifications."""

    event: str = Field(..., description="Event type (change_detected, sync_complete)")
    skill_name: str = Field(..., description="Skill name")
    # NOTE(review): naive UTC timestamp (datetime.utcnow); plain json.dumps
    # cannot serialize this field — senders must use a datetime-aware encoder.
    timestamp: datetime = Field(
        default_factory=datetime.utcnow,
        description="Event timestamp"
    )
    changes: Optional[ChangeReport] = Field(None, description="Change report")
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata"
    )

    class Config:
        # Example surfaced in the generated JSON schema / API docs.
        json_schema_extra = {
            "example": {
                "event": "change_detected",
                "skill_name": "react",
                "timestamp": "2024-01-15T10:30:00Z",
                "changes": {
                    "total_pages": 150,
                    "added": [],
                    "modified": [{"url": "https://react.dev/learn"}],
                    "deleted": []
                },
                "metadata": {"source": "periodic_check"}
            }
        }
|
||||
267
src/skill_seekers/sync/monitor.py
Normal file
267
src/skill_seekers/sync/monitor.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Sync monitor for continuous documentation monitoring.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Callable
|
||||
from datetime import datetime
|
||||
import schedule
|
||||
|
||||
from .detector import ChangeDetector
|
||||
from .models import SyncConfig, SyncState, ChangeReport, WebhookPayload
|
||||
from .notifier import Notifier
|
||||
|
||||
|
||||
class SyncMonitor:
    """
    Monitors documentation for changes and triggers updates.

    Features:
    - Continuous monitoring with configurable intervals
    - State persistence (resume after restart)
    - Change detection and diff generation
    - Notification system
    - Auto-update capability

    Examples:
        # Basic usage
        monitor = SyncMonitor(
            config_path="configs/react.json",
            check_interval=3600
        )
        monitor.start()

        # With auto-update
        monitor = SyncMonitor(
            config_path="configs/react.json",
            auto_update=True,
            on_change=lambda report: print(f"Detected {report.change_count} changes")
        )

        # Run once
        changes = monitor.check_now()
    """

    def __init__(
        self,
        config_path: str,
        check_interval: int = 3600,
        auto_update: bool = False,
        state_file: Optional[str] = None,
        on_change: Optional[Callable[[ChangeReport], None]] = None
    ):
        """
        Initialize sync monitor.

        Args:
            config_path: Path to skill config file
            check_interval: Check interval in seconds
            auto_update: Auto-rebuild skill on changes
            state_file: Path to state file (default: {skill_name}_sync.json)
            on_change: Callback function for change events

        Raises:
            OSError: If the config file cannot be read.
            json.JSONDecodeError: If the config file is not valid JSON.
        """
        self.config_path = Path(config_path)
        self.check_interval = check_interval
        self.auto_update = auto_update
        self.on_change = on_change

        # Load skill config
        with open(self.config_path) as f:
            self.skill_config = json.load(f)

        self.skill_name = self.skill_config.get('name', 'unknown')

        # State file (defaults to a per-skill JSON file in the CWD).
        if state_file:
            self.state_file = Path(state_file)
        else:
            self.state_file = Path(f"{self.skill_name}_sync.json")

        # Initialize components
        self.detector = ChangeDetector()
        self.notifier = Notifier()

        # Load state (resumes hashes/counters from a previous run, if any).
        self.state = self._load_state()

        # Threading / scheduling handles.
        self._running = False
        self._thread = None
        # Handle to the job registered with the global `schedule` registry,
        # kept so stop() can cancel exactly this job.
        self._job = None

    def _load_state(self) -> SyncState:
        """Load state from file or create a fresh one for this skill."""
        if self.state_file.exists():
            with open(self.state_file) as f:
                data = json.load(f)
            # Datetimes were saved as ISO strings; convert them back.
            if data.get('last_check'):
                data['last_check'] = datetime.fromisoformat(data['last_check'])
            if data.get('last_change'):
                data['last_change'] = datetime.fromisoformat(data['last_change'])
            return SyncState(**data)
        else:
            return SyncState(skill_name=self.skill_name)

    def _save_state(self):
        """Persist current state to the state file as JSON."""
        # NOTE(review): .dict() is the pydantic v1 export API — confirm the
        # installed pydantic major version before migrating to model_dump().
        data = self.state.dict()
        # Datetimes are not JSON-serializable; store as ISO strings.
        if data.get('last_check'):
            data['last_check'] = data['last_check'].isoformat()
        if data.get('last_change'):
            data['last_change'] = data['last_change'].isoformat()

        with open(self.state_file, 'w') as f:
            json.dump(data, f, indent=2)

    def check_now(self, generate_diffs: bool = False) -> ChangeReport:
        """
        Check for changes now (synchronous).

        Updates persisted state (hashes, counters, status), fires the
        on_change callback and notifications when changes are found, and
        triggers an auto-update when enabled.

        Args:
            generate_diffs: Whether to generate content diffs

        Returns:
            ChangeReport with detected changes

        Raises:
            Exception: Re-raises whatever the detector raised, after
                recording it in state.error / state.status.
        """
        self.state.status = "checking"
        self._save_state()

        try:
            # Get URLs to check from config
            base_url = self.skill_config.get('base_url')
            # TODO: In real implementation, get actual URLs from scraper

            # For now, simulate with base URL only
            urls = [base_url] if base_url else []

            # Check for changes
            report = self.detector.check_pages(
                urls=urls,
                previous_hashes=self.state.page_hashes,
                generate_diffs=generate_diffs
            )
            report.skill_name = self.skill_name

            # Update state
            self.state.last_check = datetime.utcnow()
            self.state.total_checks += 1

            if report.has_changes:
                self.state.last_change = datetime.utcnow()
                self.state.total_changes += report.change_count

                # Update hashes for new/modified pages so the next check
                # compares against the latest content.
                for change in report.added + report.modified:
                    if change.new_hash:
                        self.state.page_hashes[change.url] = change.new_hash

                # Remove deleted pages
                for change in report.deleted:
                    self.state.page_hashes.pop(change.url, None)

                # Trigger callback
                if self.on_change:
                    self.on_change(report)

                # Send notifications
                self._notify(report)

                # Auto-update if enabled
                if self.auto_update:
                    self._trigger_update(report)

            self.state.status = "idle"
            self.state.error = None

            return report

        except Exception as e:
            # Record the failure so stats()/state file reflect it, then
            # propagate to the caller.
            self.state.status = "error"
            self.state.error = str(e)
            raise
        finally:
            self._save_state()

    def _notify(self, report: ChangeReport):
        """Send notifications about changes via the configured notifier."""
        payload = WebhookPayload(
            event="change_detected",
            skill_name=self.skill_name,
            changes=report,
            metadata={"auto_update": self.auto_update}
        )

        self.notifier.send(payload)

    def _trigger_update(self, report: ChangeReport):
        """Trigger skill rebuild (currently log-only placeholder)."""
        print(f"🔄 Auto-updating {self.skill_name} due to {report.change_count} changes...")
        # TODO: Integrate with doc_scraper to rebuild skill
        # For now, just log
        print(f"  Added: {len(report.added)}")
        print(f"  Modified: {len(report.modified)}")
        print(f"  Deleted: {len(report.deleted)}")

    def start(self):
        """
        Start continuous monitoring in a background daemon thread.

        Runs one check immediately, then every `check_interval` seconds.

        Raises:
            RuntimeError: If the monitor is already running.
        """
        if self._running:
            raise RuntimeError("Monitor is already running")

        self._running = True

        # Keep the job handle so stop() can cancel exactly this job; the
        # `schedule` registry is process-global, so clearing it wholesale
        # would break other monitors in the same process.
        self._job = schedule.every(self.check_interval).seconds.do(
            lambda: self.check_now()
        )

        # Run the scheduler loop in a daemon thread so it never blocks
        # interpreter shutdown.
        def run_schedule():
            while self._running:
                schedule.run_pending()
                time.sleep(1)

        self._thread = threading.Thread(target=run_schedule, daemon=True)
        self._thread.start()

        print(f"✅ Started monitoring {self.skill_name} (every {self.check_interval}s)")

        # Run first check immediately
        self.check_now()

    def stop(self):
        """Stop monitoring and cancel the scheduled job. Safe to call twice."""
        if not self._running:
            return

        self._running = False

        # BUG FIX: the job must be removed from `schedule`'s global
        # registry; previously it was left behind, so a stop()/start()
        # cycle registered a duplicate job and checks ran twice per
        # interval (doubling again on every restart).
        if self._job is not None:
            schedule.cancel_job(self._job)
            self._job = None

        if self._thread:
            self._thread.join(timeout=5)

        print(f"🛑 Stopped monitoring {self.skill_name}")

    def stats(self) -> Dict:
        """Get monitoring statistics as a JSON-friendly dict."""
        return {
            "skill_name": self.skill_name,
            "status": self.state.status,
            "last_check": self.state.last_check.isoformat() if self.state.last_check else None,
            "last_change": self.state.last_change.isoformat() if self.state.last_change else None,
            "total_checks": self.state.total_checks,
            "total_changes": self.state.total_changes,
            "tracked_pages": len(self.state.page_hashes),
            "running": self._running,
        }

    def __enter__(self):
        """Context manager entry: start monitoring."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: stop monitoring."""
        self.stop()
|
||||
144
src/skill_seekers/sync/notifier.py
Normal file
144
src/skill_seekers/sync/notifier.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Notification system for sync events.
|
||||
"""
|
||||
|
||||
import os
|
||||
import requests
|
||||
from typing import Optional, List
|
||||
from .models import WebhookPayload
|
||||
|
||||
|
||||
class Notifier:
    """
    Send notifications about sync events.

    Supports:
    - Webhook (HTTP POST)
    - Slack (via webhook)
    - Email (SMTP) - TODO
    - Console (stdout)

    Examples:
        notifier = Notifier()

        payload = WebhookPayload(
            event="change_detected",
            skill_name="react",
            changes=report
        )

        notifier.send(payload)
    """

    def __init__(
        self,
        webhook_url: Optional[str] = None,
        slack_webhook: Optional[str] = None,
        email_recipients: Optional[List[str]] = None,
        console: bool = True
    ):
        """
        Initialize notifier.

        Args:
            webhook_url: Webhook URL for HTTP notifications
                (falls back to the SYNC_WEBHOOK_URL env var)
            slack_webhook: Slack webhook URL
                (falls back to the SLACK_WEBHOOK_URL env var)
            email_recipients: List of email recipients
            console: Whether to print to console
        """
        self.webhook_url = webhook_url or os.getenv('SYNC_WEBHOOK_URL')
        self.slack_webhook = slack_webhook or os.getenv('SLACK_WEBHOOK_URL')
        self.email_recipients = email_recipients or []
        self.console = console

    def send(self, payload: WebhookPayload):
        """
        Send notification via all configured channels.

        Channel failures are reported to stdout but never raised, so one
        broken channel cannot block the others.

        Args:
            payload: Notification payload
        """
        if self.console:
            self._send_console(payload)

        if self.webhook_url:
            self._send_webhook(payload)

        if self.slack_webhook:
            self._send_slack(payload)

        if self.email_recipients:
            self._send_email(payload)

    def _send_console(self, payload: WebhookPayload):
        """Print a human-readable summary of the event to stdout."""
        print(f"\n📢 {payload.event.upper()}: {payload.skill_name}")

        if payload.changes:
            changes = payload.changes
            if changes.has_changes:
                print(f"  Changes detected: {changes.change_count}")
                if changes.added:
                    print(f"  ✅ Added: {len(changes.added)} pages")
                if changes.modified:
                    print(f"  ✏️  Modified: {len(changes.modified)} pages")
                if changes.deleted:
                    print(f"  ❌ Deleted: {len(changes.deleted)} pages")
            else:
                print("  No changes detected")

    def _send_webhook(self, payload: WebhookPayload):
        """POST the full payload to the configured generic webhook."""
        try:
            # BUG FIX: previously this used `json=payload.dict()`, which
            # routes through json.dumps — and json.dumps cannot serialize
            # the payload's datetime fields, so every send raised
            # TypeError before the request left the process.  Pydantic's
            # .json() export encodes datetimes correctly, so send its
            # output as the raw request body instead.
            # (pydantic v1 API; under v2 this would be model_dump_json())
            response = requests.post(
                self.webhook_url,
                data=payload.json(),
                headers={'Content-Type': 'application/json'},
                timeout=10
            )
            response.raise_for_status()
            print(f"✅ Webhook notification sent to {self.webhook_url}")
        except Exception as e:
            # Best-effort delivery: log and continue with other channels.
            print(f"❌ Failed to send webhook: {e}")

    def _send_slack(self, payload: WebhookPayload):
        """Send a formatted summary message to Slack via incoming webhook."""
        try:
            # Format Slack message (mrkdwn syntax).
            text = f"*{payload.event.upper()}*: {payload.skill_name}"

            if payload.changes and payload.changes.has_changes:
                changes = payload.changes
                text += f"\n• Changes: {changes.change_count}"
                text += f"\n• Added: {len(changes.added)}"
                text += f"\n• Modified: {len(changes.modified)}"
                text += f"\n• Deleted: {len(changes.deleted)}"

                # Add URLs of changed pages
                if changes.modified:
                    text += "\n\n*Modified Pages:*"
                    for change in changes.modified[:5]:  # Limit to 5
                        text += f"\n• {change.url}"
                    if len(changes.modified) > 5:
                        text += f"\n• ...and {len(changes.modified) - 5} more"

            slack_payload = {
                "text": text,
                "username": "Skill Seekers Sync",
                "icon_emoji": ":books:"
            }

            response = requests.post(
                self.slack_webhook,
                json=slack_payload,
                timeout=10
            )
            response.raise_for_status()
            print("✅ Slack notification sent")
        except Exception as e:
            # Best-effort delivery: log and continue with other channels.
            print(f"❌ Failed to send Slack notification: {e}")

    def _send_email(self, payload: WebhookPayload):
        """Send email notification."""
        # TODO: Implement SMTP email sending
        print(f"📧 Email notification (not implemented): {self.email_recipients}")
|
||||
Reference in New Issue
Block a user