style: Format all Python files with ruff
- Formatted 103 files to comply with ruff format requirements - No code logic changes, only formatting/whitespace - Fixes CI formatting check failures
This commit is contained in:
@@ -23,9 +23,9 @@ from .generator import EmbeddingGenerator
|
||||
from .cache import EmbeddingCache
|
||||
|
||||
__all__ = [
|
||||
'EmbeddingRequest',
|
||||
'EmbeddingResponse',
|
||||
'BatchEmbeddingRequest',
|
||||
'EmbeddingGenerator',
|
||||
'EmbeddingCache',
|
||||
"EmbeddingRequest",
|
||||
"EmbeddingResponse",
|
||||
"BatchEmbeddingRequest",
|
||||
"EmbeddingGenerator",
|
||||
"EmbeddingCache",
|
||||
]
|
||||
|
||||
@@ -74,12 +74,7 @@ class EmbeddingCache:
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def set(
|
||||
self,
|
||||
hash_key: str,
|
||||
embedding: list[float],
|
||||
model: str
|
||||
) -> None:
|
||||
def set(self, hash_key: str, embedding: list[float], model: str) -> None:
|
||||
"""
|
||||
Store embedding in cache.
|
||||
|
||||
@@ -94,11 +89,14 @@ class EmbeddingCache:
|
||||
embedding_json = json.dumps(embedding)
|
||||
dimensions = len(embedding)
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO embeddings
|
||||
(hash, embedding, model, dimensions, created_at, accessed_at, access_count)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 1)
|
||||
""", (hash_key, embedding_json, model, dimensions, now, now))
|
||||
""",
|
||||
(hash_key, embedding_json, model, dimensions, now, now),
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
@@ -115,11 +113,14 @@ class EmbeddingCache:
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
# Get embedding
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT embedding, created_at
|
||||
FROM embeddings
|
||||
WHERE hash = ?
|
||||
""", (hash_key,))
|
||||
""",
|
||||
(hash_key,),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
@@ -136,11 +137,14 @@ class EmbeddingCache:
|
||||
|
||||
# Update access stats
|
||||
now = datetime.utcnow().isoformat()
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE embeddings
|
||||
SET accessed_at = ?, access_count = access_count + 1
|
||||
WHERE hash = ?
|
||||
""", (now, hash_key))
|
||||
""",
|
||||
(now, hash_key),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
return json.loads(embedding_json)
|
||||
@@ -178,11 +182,14 @@ class EmbeddingCache:
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT created_at
|
||||
FROM embeddings
|
||||
WHERE hash = ?
|
||||
""", (hash_key,))
|
||||
""",
|
||||
(hash_key,),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
@@ -206,10 +213,13 @@ class EmbeddingCache:
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM embeddings
|
||||
WHERE hash = ?
|
||||
""", (hash_key,))
|
||||
""",
|
||||
(hash_key,),
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
@@ -226,10 +236,13 @@ class EmbeddingCache:
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
if model:
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM embeddings
|
||||
WHERE model = ?
|
||||
""", (model,))
|
||||
""",
|
||||
(model,),
|
||||
)
|
||||
else:
|
||||
cursor.execute("DELETE FROM embeddings")
|
||||
|
||||
@@ -249,10 +262,13 @@ class EmbeddingCache:
|
||||
|
||||
cutoff = (datetime.utcnow() - timedelta(days=self.ttl_days)).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM embeddings
|
||||
WHERE created_at < ?
|
||||
""", (cutoff,))
|
||||
""",
|
||||
(cutoff,),
|
||||
)
|
||||
|
||||
deleted = cursor.rowcount
|
||||
self.conn.commit()
|
||||
@@ -300,17 +316,19 @@ class EmbeddingCache:
|
||||
LIMIT 10
|
||||
""")
|
||||
top_accessed = [
|
||||
{"hash": row[0], "model": row[1], "access_count": row[2]}
|
||||
for row in cursor.fetchall()
|
||||
{"hash": row[0], "model": row[1], "access_count": row[2]} for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
# Expired entries
|
||||
cutoff = (datetime.utcnow() - timedelta(days=self.ttl_days)).isoformat()
|
||||
cursor.execute("""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM embeddings
|
||||
WHERE created_at < ?
|
||||
""", (cutoff,))
|
||||
""",
|
||||
(cutoff,),
|
||||
)
|
||||
expired = cursor.fetchone()[0]
|
||||
|
||||
return {
|
||||
@@ -318,7 +336,7 @@ class EmbeddingCache:
|
||||
"by_model": by_model,
|
||||
"top_accessed": top_accessed,
|
||||
"expired": expired,
|
||||
"ttl_days": self.ttl_days
|
||||
"ttl_days": self.ttl_days,
|
||||
}
|
||||
|
||||
def close(self):
|
||||
|
||||
@@ -9,6 +9,7 @@ import numpy as np
|
||||
# OpenAI support
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
OPENAI_AVAILABLE = False
|
||||
@@ -16,6 +17,7 @@ except ImportError:
|
||||
# Sentence transformers support
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
SENTENCE_TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SENTENCE_TRANSFORMERS_AVAILABLE = False
|
||||
@@ -23,6 +25,7 @@ except ImportError:
|
||||
# Voyage AI support (recommended by Anthropic for embeddings)
|
||||
try:
|
||||
import voyageai
|
||||
|
||||
VOYAGE_AVAILABLE = True
|
||||
except ImportError:
|
||||
VOYAGE_AVAILABLE = False
|
||||
@@ -129,7 +132,7 @@ class EmbeddingGenerator:
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
voyage_api_key: str | None = None,
|
||||
cache_dir: str | None = None
|
||||
cache_dir: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize embedding generator.
|
||||
@@ -162,8 +165,7 @@ class EmbeddingGenerator:
|
||||
"""Get information about a model."""
|
||||
if model not in self.MODELS:
|
||||
raise ValueError(
|
||||
f"Unknown model: {model}. "
|
||||
f"Available models: {', '.join(self.MODELS.keys())}"
|
||||
f"Unknown model: {model}. Available models: {', '.join(self.MODELS.keys())}"
|
||||
)
|
||||
return self.MODELS[model]
|
||||
|
||||
@@ -171,20 +173,19 @@ class EmbeddingGenerator:
|
||||
"""List all available models."""
|
||||
models = []
|
||||
for name, info in self.MODELS.items():
|
||||
models.append({
|
||||
"name": name,
|
||||
"provider": info["provider"],
|
||||
"dimensions": info["dimensions"],
|
||||
"max_tokens": info["max_tokens"],
|
||||
"cost_per_million": info.get("cost_per_million", 0.0),
|
||||
})
|
||||
models.append(
|
||||
{
|
||||
"name": name,
|
||||
"provider": info["provider"],
|
||||
"dimensions": info["dimensions"],
|
||||
"max_tokens": info["max_tokens"],
|
||||
"cost_per_million": info.get("cost_per_million", 0.0),
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
def generate(
|
||||
self,
|
||||
text: str,
|
||||
model: str = "text-embedding-3-small",
|
||||
normalize: bool = True
|
||||
self, text: str, model: str = "text-embedding-3-small", normalize: bool = True
|
||||
) -> list[float]:
|
||||
"""
|
||||
Generate embedding for a single text.
|
||||
@@ -218,7 +219,7 @@ class EmbeddingGenerator:
|
||||
texts: list[str],
|
||||
model: str = "text-embedding-3-small",
|
||||
normalize: bool = True,
|
||||
batch_size: int = 32
|
||||
batch_size: int = 32,
|
||||
) -> tuple[list[list[float]], int]:
|
||||
"""
|
||||
Generate embeddings for multiple texts.
|
||||
@@ -248,24 +249,18 @@ class EmbeddingGenerator:
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
|
||||
def _generate_openai(
|
||||
self, text: str, model: str, normalize: bool
|
||||
) -> list[float]:
|
||||
def _generate_openai(self, text: str, model: str, normalize: bool) -> list[float]:
|
||||
"""Generate embedding using OpenAI API."""
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise ImportError(
|
||||
"OpenAI is required for OpenAI embeddings. "
|
||||
"Install with: pip install openai"
|
||||
"OpenAI is required for OpenAI embeddings. Install with: pip install openai"
|
||||
)
|
||||
|
||||
if not self.openai_client:
|
||||
raise ValueError("OpenAI API key not provided")
|
||||
|
||||
try:
|
||||
response = self.openai_client.embeddings.create(
|
||||
input=text,
|
||||
model=model
|
||||
)
|
||||
response = self.openai_client.embeddings.create(input=text, model=model)
|
||||
embedding = response.data[0].embedding
|
||||
|
||||
if normalize:
|
||||
@@ -281,8 +276,7 @@ class EmbeddingGenerator:
|
||||
"""Generate embeddings using OpenAI API in batches."""
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise ImportError(
|
||||
"OpenAI is required for OpenAI embeddings. "
|
||||
"Install with: pip install openai"
|
||||
"OpenAI is required for OpenAI embeddings. Install with: pip install openai"
|
||||
)
|
||||
|
||||
if not self.openai_client:
|
||||
@@ -292,13 +286,10 @@ class EmbeddingGenerator:
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i + batch_size]
|
||||
batch = texts[i : i + batch_size]
|
||||
|
||||
try:
|
||||
response = self.openai_client.embeddings.create(
|
||||
input=batch,
|
||||
model=model
|
||||
)
|
||||
response = self.openai_client.embeddings.create(input=batch, model=model)
|
||||
|
||||
batch_embeddings = [item.embedding for item in response.data]
|
||||
|
||||
@@ -313,24 +304,18 @@ class EmbeddingGenerator:
|
||||
dimensions = len(all_embeddings[0]) if all_embeddings else 0
|
||||
return all_embeddings, dimensions
|
||||
|
||||
def _generate_voyage(
|
||||
self, text: str, model: str, normalize: bool
|
||||
) -> list[float]:
|
||||
def _generate_voyage(self, text: str, model: str, normalize: bool) -> list[float]:
|
||||
"""Generate embedding using Voyage AI API."""
|
||||
if not VOYAGE_AVAILABLE:
|
||||
raise ImportError(
|
||||
"voyageai is required for Voyage AI embeddings. "
|
||||
"Install with: pip install voyageai"
|
||||
"voyageai is required for Voyage AI embeddings. Install with: pip install voyageai"
|
||||
)
|
||||
|
||||
if not self.voyage_client:
|
||||
raise ValueError("Voyage API key not provided")
|
||||
|
||||
try:
|
||||
result = self.voyage_client.embed(
|
||||
texts=[text],
|
||||
model=model
|
||||
)
|
||||
result = self.voyage_client.embed(texts=[text], model=model)
|
||||
embedding = result.embeddings[0]
|
||||
|
||||
if normalize:
|
||||
@@ -346,8 +331,7 @@ class EmbeddingGenerator:
|
||||
"""Generate embeddings using Voyage AI API in batches."""
|
||||
if not VOYAGE_AVAILABLE:
|
||||
raise ImportError(
|
||||
"voyageai is required for Voyage AI embeddings. "
|
||||
"Install with: pip install voyageai"
|
||||
"voyageai is required for Voyage AI embeddings. Install with: pip install voyageai"
|
||||
)
|
||||
|
||||
if not self.voyage_client:
|
||||
@@ -357,13 +341,10 @@ class EmbeddingGenerator:
|
||||
|
||||
# Process in batches (Voyage AI supports up to 128 texts per request)
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i + batch_size]
|
||||
batch = texts[i : i + batch_size]
|
||||
|
||||
try:
|
||||
result = self.voyage_client.embed(
|
||||
texts=batch,
|
||||
model=model
|
||||
)
|
||||
result = self.voyage_client.embed(texts=batch, model=model)
|
||||
|
||||
batch_embeddings = result.embeddings
|
||||
|
||||
@@ -378,9 +359,7 @@ class EmbeddingGenerator:
|
||||
dimensions = len(all_embeddings[0]) if all_embeddings else 0
|
||||
return all_embeddings, dimensions
|
||||
|
||||
def _generate_sentence_transformer(
|
||||
self, text: str, model: str, normalize: bool
|
||||
) -> list[float]:
|
||||
def _generate_sentence_transformer(self, text: str, model: str, normalize: bool) -> list[float]:
|
||||
"""Generate embedding using sentence-transformers."""
|
||||
if not SENTENCE_TRANSFORMERS_AVAILABLE:
|
||||
raise ImportError(
|
||||
@@ -417,10 +396,7 @@ class EmbeddingGenerator:
|
||||
|
||||
# Generate embeddings in batches
|
||||
embeddings = st_model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
normalize_embeddings=normalize,
|
||||
show_progress_bar=False
|
||||
texts, batch_size=batch_size, normalize_embeddings=normalize, show_progress_bar=False
|
||||
)
|
||||
|
||||
dimensions = len(embeddings[0]) if len(embeddings) > 0 else 0
|
||||
|
||||
@@ -14,20 +14,14 @@ class EmbeddingRequest(BaseModel):
|
||||
"example": {
|
||||
"text": "This is a test document about Python programming.",
|
||||
"model": "text-embedding-3-small",
|
||||
"normalize": True
|
||||
"normalize": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
text: str = Field(..., description="Text to generate embedding for")
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
description="Embedding model to use"
|
||||
)
|
||||
normalize: bool = Field(
|
||||
default=True,
|
||||
description="Normalize embeddings to unit length"
|
||||
)
|
||||
model: str = Field(default="text-embedding-3-small", description="Embedding model to use")
|
||||
normalize: bool = Field(default=True, description="Normalize embeddings to unit length")
|
||||
|
||||
|
||||
class BatchEmbeddingRequest(BaseModel):
|
||||
@@ -39,27 +33,20 @@ class BatchEmbeddingRequest(BaseModel):
|
||||
"texts": [
|
||||
"First document about Python",
|
||||
"Second document about JavaScript",
|
||||
"Third document about Rust"
|
||||
"Third document about Rust",
|
||||
],
|
||||
"model": "text-embedding-3-small",
|
||||
"normalize": True,
|
||||
"batch_size": 32
|
||||
"batch_size": 32,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
texts: list[str] = Field(..., description="List of texts to embed")
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
description="Embedding model to use"
|
||||
)
|
||||
normalize: bool = Field(
|
||||
default=True,
|
||||
description="Normalize embeddings to unit length"
|
||||
)
|
||||
model: str = Field(default="text-embedding-3-small", description="Embedding model to use")
|
||||
normalize: bool = Field(default=True, description="Normalize embeddings to unit length")
|
||||
batch_size: int | None = Field(
|
||||
default=32,
|
||||
description="Batch size for processing (default: 32)"
|
||||
default=32, description="Batch size for processing (default: 32)"
|
||||
)
|
||||
|
||||
|
||||
@@ -69,10 +56,7 @@ class EmbeddingResponse(BaseModel):
|
||||
embedding: list[float] = Field(..., description="Generated embedding vector")
|
||||
model: str = Field(..., description="Model used for generation")
|
||||
dimensions: int = Field(..., description="Embedding dimensions")
|
||||
cached: bool = Field(
|
||||
default=False,
|
||||
description="Whether embedding was retrieved from cache"
|
||||
)
|
||||
cached: bool = Field(default=False, description="Whether embedding was retrieved from cache")
|
||||
|
||||
|
||||
class BatchEmbeddingResponse(BaseModel):
|
||||
@@ -82,10 +66,7 @@ class BatchEmbeddingResponse(BaseModel):
|
||||
model: str = Field(..., description="Model used for generation")
|
||||
dimensions: int = Field(..., description="Embedding dimensions")
|
||||
count: int = Field(..., description="Number of embeddings generated")
|
||||
cached_count: int = Field(
|
||||
default=0,
|
||||
description="Number of embeddings retrieved from cache"
|
||||
)
|
||||
cached_count: int = Field(default=0, description="Number of embeddings retrieved from cache")
|
||||
|
||||
|
||||
class SkillEmbeddingRequest(BaseModel):
|
||||
@@ -97,24 +78,15 @@ class SkillEmbeddingRequest(BaseModel):
|
||||
"skill_path": "/path/to/skill/react",
|
||||
"model": "text-embedding-3-small",
|
||||
"chunk_size": 512,
|
||||
"overlap": 50
|
||||
"overlap": 50,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
skill_path: str = Field(..., description="Path to skill directory")
|
||||
model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
description="Embedding model to use"
|
||||
)
|
||||
chunk_size: int = Field(
|
||||
default=512,
|
||||
description="Chunk size for splitting documents (tokens)"
|
||||
)
|
||||
overlap: int = Field(
|
||||
default=50,
|
||||
description="Overlap between chunks (tokens)"
|
||||
)
|
||||
model: str = Field(default="text-embedding-3-small", description="Embedding model to use")
|
||||
chunk_size: int = Field(default=512, description="Chunk size for splitting documents (tokens)")
|
||||
overlap: int = Field(default=50, description="Overlap between chunks (tokens)")
|
||||
|
||||
|
||||
class SkillEmbeddingResponse(BaseModel):
|
||||
@@ -124,10 +96,7 @@ class SkillEmbeddingResponse(BaseModel):
|
||||
total_chunks: int = Field(..., description="Total number of chunks embedded")
|
||||
model: str = Field(..., description="Model used for generation")
|
||||
dimensions: int = Field(..., description="Embedding dimensions")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Skill metadata"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Skill metadata")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
@@ -144,12 +113,13 @@ class ModelInfo(BaseModel):
|
||||
"""Information about an embedding model."""
|
||||
|
||||
name: str = Field(..., description="Model name")
|
||||
provider: str = Field(..., description="Model provider (openai, anthropic, sentence-transformers)")
|
||||
provider: str = Field(
|
||||
..., description="Model provider (openai, anthropic, sentence-transformers)"
|
||||
)
|
||||
dimensions: int = Field(..., description="Embedding dimensions")
|
||||
max_tokens: int = Field(..., description="Maximum input tokens")
|
||||
cost_per_million: float | None = Field(
|
||||
None,
|
||||
description="Cost per million tokens (if applicable)"
|
||||
None, description="Cost per million tokens (if applicable)"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ try:
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
FASTAPI_AVAILABLE = True
|
||||
except ImportError:
|
||||
FASTAPI_AVAILABLE = False
|
||||
@@ -51,7 +52,7 @@ if FASTAPI_AVAILABLE:
|
||||
description="Generate embeddings for text and skill content",
|
||||
version="1.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc"
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
@@ -64,13 +65,14 @@ if FASTAPI_AVAILABLE:
|
||||
)
|
||||
|
||||
# Initialize generator and cache
|
||||
cache_dir = os.getenv("EMBEDDING_CACHE_DIR", os.path.expanduser("~/.cache/skill-seekers/embeddings"))
|
||||
cache_dir = os.getenv(
|
||||
"EMBEDDING_CACHE_DIR", os.path.expanduser("~/.cache/skill-seekers/embeddings")
|
||||
)
|
||||
cache_db = os.path.join(cache_dir, "embeddings.db")
|
||||
cache_enabled = os.getenv("EMBEDDING_CACHE_ENABLED", "true").lower() == "true"
|
||||
|
||||
generator = EmbeddingGenerator(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
voyage_api_key=os.getenv("VOYAGE_API_KEY")
|
||||
api_key=os.getenv("OPENAI_API_KEY"), voyage_api_key=os.getenv("VOYAGE_API_KEY")
|
||||
)
|
||||
cache = EmbeddingCache(cache_db) if cache_enabled else None
|
||||
|
||||
@@ -81,7 +83,7 @@ if FASTAPI_AVAILABLE:
|
||||
"service": "Skill Seekers Embedding API",
|
||||
"version": "1.0.0",
|
||||
"docs": "/docs",
|
||||
"health": "/health"
|
||||
"health": "/health",
|
||||
}
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
@@ -95,7 +97,7 @@ if FASTAPI_AVAILABLE:
|
||||
version="1.0.0",
|
||||
models=models,
|
||||
cache_enabled=cache_enabled,
|
||||
cache_size=cache_size
|
||||
cache_size=cache_size,
|
||||
)
|
||||
|
||||
@app.get("/models", response_model=ModelsResponse)
|
||||
@@ -109,15 +111,12 @@ if FASTAPI_AVAILABLE:
|
||||
provider=m["provider"],
|
||||
dimensions=m["dimensions"],
|
||||
max_tokens=m["max_tokens"],
|
||||
cost_per_million=m.get("cost_per_million")
|
||||
cost_per_million=m.get("cost_per_million"),
|
||||
)
|
||||
for m in models_list
|
||||
]
|
||||
|
||||
return ModelsResponse(
|
||||
models=model_infos,
|
||||
count=len(model_infos)
|
||||
)
|
||||
return ModelsResponse(models=model_infos, count=len(model_infos))
|
||||
|
||||
@app.post("/embed", response_model=EmbeddingResponse)
|
||||
async def embed_text(request: EmbeddingRequest):
|
||||
@@ -144,9 +143,7 @@ if FASTAPI_AVAILABLE:
|
||||
else:
|
||||
# Generate embedding
|
||||
embedding = generator.generate(
|
||||
request.text,
|
||||
model=request.model,
|
||||
normalize=request.normalize
|
||||
request.text, model=request.model, normalize=request.normalize
|
||||
)
|
||||
|
||||
# Store in cache
|
||||
@@ -154,10 +151,7 @@ if FASTAPI_AVAILABLE:
|
||||
cache.set(hash_key, embedding, request.model)
|
||||
|
||||
return EmbeddingResponse(
|
||||
embedding=embedding,
|
||||
model=request.model,
|
||||
dimensions=len(embedding),
|
||||
cached=cached
|
||||
embedding=embedding, model=request.model, dimensions=len(embedding), cached=cached
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -202,11 +196,13 @@ if FASTAPI_AVAILABLE:
|
||||
texts_to_generate,
|
||||
model=request.model,
|
||||
normalize=request.normalize,
|
||||
batch_size=request.batch_size
|
||||
batch_size=request.batch_size,
|
||||
)
|
||||
|
||||
# Fill in placeholders and cache
|
||||
for idx, text, embedding in zip(text_indices, texts_to_generate, generated_embeddings, strict=False):
|
||||
for idx, text, embedding in zip(
|
||||
text_indices, texts_to_generate, generated_embeddings, strict=False
|
||||
):
|
||||
embeddings[idx] = embedding
|
||||
|
||||
if cache:
|
||||
@@ -220,7 +216,7 @@ if FASTAPI_AVAILABLE:
|
||||
model=request.model,
|
||||
dimensions=dimensions,
|
||||
count=len(embeddings),
|
||||
cached_count=cached_count
|
||||
cached_count=cached_count,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -244,12 +240,16 @@ if FASTAPI_AVAILABLE:
|
||||
skill_path = Path(request.skill_path)
|
||||
|
||||
if not skill_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Skill path not found: {request.skill_path}")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Skill path not found: {request.skill_path}"
|
||||
)
|
||||
|
||||
# Read SKILL.md
|
||||
skill_md = skill_path / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
raise HTTPException(status_code=404, detail=f"SKILL.md not found in {request.skill_path}")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"SKILL.md not found in {request.skill_path}"
|
||||
)
|
||||
|
||||
skill_content = skill_md.read_text()
|
||||
|
||||
@@ -262,10 +262,7 @@ if FASTAPI_AVAILABLE:
|
||||
|
||||
# Generate embeddings for chunks
|
||||
embeddings, dimensions = generator.generate_batch(
|
||||
chunks,
|
||||
model=request.model,
|
||||
normalize=True,
|
||||
batch_size=32
|
||||
chunks, model=request.model, normalize=True, batch_size=32
|
||||
)
|
||||
|
||||
# TODO: Store embeddings in vector database
|
||||
@@ -279,8 +276,8 @@ if FASTAPI_AVAILABLE:
|
||||
metadata={
|
||||
"skill_path": str(skill_path),
|
||||
"chunks": len(chunks),
|
||||
"content_length": len(skill_content)
|
||||
}
|
||||
"content_length": len(skill_content),
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -298,7 +295,7 @@ if FASTAPI_AVAILABLE:
|
||||
|
||||
@app.post("/cache/clear", response_model=dict)
|
||||
async def clear_cache(
|
||||
model: str | None = Query(None, description="Model to clear (all if not specified)")
|
||||
model: str | None = Query(None, description="Model to clear (all if not specified)"),
|
||||
):
|
||||
"""Clear cache entries."""
|
||||
if not cache:
|
||||
@@ -306,11 +303,7 @@ if FASTAPI_AVAILABLE:
|
||||
|
||||
deleted = cache.clear(model=model)
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"deleted": deleted,
|
||||
"model": model or "all"
|
||||
}
|
||||
return {"status": "ok", "deleted": deleted, "model": model or "all"}
|
||||
|
||||
@app.post("/cache/clear-expired", response_model=dict)
|
||||
async def clear_expired():
|
||||
@@ -320,10 +313,7 @@ if FASTAPI_AVAILABLE:
|
||||
|
||||
deleted = cache.clear_expired()
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"deleted": deleted
|
||||
}
|
||||
return {"status": "ok", "deleted": deleted}
|
||||
|
||||
else:
|
||||
print("Error: FastAPI not available. Install with: pip install fastapi uvicorn")
|
||||
@@ -348,12 +338,7 @@ def main():
|
||||
if cache_enabled:
|
||||
print(f"💾 Cache database: {cache_db}")
|
||||
|
||||
uvicorn.run(
|
||||
"skill_seekers.embedding.server:app",
|
||||
host=host,
|
||||
port=port,
|
||||
reload=reload
|
||||
)
|
||||
uvicorn.run("skill_seekers.embedding.server:app", host=host, port=port, reload=reload)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user