style: apply ruff format to all source and test files

Fixes the `ruff format --check` CI failure. 22 files were reformatted to satisfy
the ruff formatter's style requirements. No logic changes; only
whitespace and formatting adjustments.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
yusyus
2026-02-18 22:50:05 +03:00
parent 0878ad3ef6
commit 4b89e0a015
22 changed files with 707 additions and 597 deletions

View File

@@ -485,8 +485,7 @@ def extract_rst_structure(content: str) -> dict[str, Any]:
structure = {
"title": doc.title,
"headers": [
{"level": h.level, "text": h.text, "line": h.source_line}
for h in doc.headings
{"level": h.level, "text": h.text, "line": h.source_line} for h in doc.headings
],
"code_blocks": [
{
@@ -508,12 +507,10 @@ def extract_rst_structure(content: str) -> dict[str, Any]:
for t in doc.tables
],
"links": [
{"text": x.text or x.target, "url": x.target}
for x in doc.external_links
{"text": x.text or x.target, "url": x.target} for x in doc.external_links
],
"cross_references": [
{"type": x.ref_type.value, "target": x.target}
for x in doc.internal_links
{"type": x.ref_type.value, "target": x.target} for x in doc.internal_links
],
"word_count": len(content.split()),
"line_count": len(content.split("\n")),
@@ -569,7 +566,9 @@ def extract_rst_structure(content: str) -> dict[str, Any]:
structure["title"] = text
# Basic code block extraction
code_block_pattern = re.compile(r"\.\.\s+code-block::\s+(\w+)\s*\n\s+(.*?)(?=\n\S|\Z)", re.DOTALL)
code_block_pattern = re.compile(
r"\.\.\s+code-block::\s+(\w+)\s*\n\s+(.*?)(?=\n\S|\Z)", re.DOTALL
)
for match in code_block_pattern.finditer(content):
language = match.group(1) or "text"
code = match.group(2).strip()
@@ -585,9 +584,7 @@ def extract_rst_structure(content: str) -> dict[str, Any]:
# Basic link extraction
link_pattern = re.compile(r"`([^<`]+)\s+<([^>]+)>`_")
for match in link_pattern.finditer(content):
structure["links"].append(
{"text": match.group(1).strip(), "url": match.group(2)}
)
structure["links"].append({"text": match.group(1).strip(), "url": match.group(2)})
return structure
@@ -729,8 +726,12 @@ def process_markdown_docs(
],
"tables": len(parsed_doc.tables),
"cross_refs": len(parsed_doc.internal_links),
"directives": len([b for b in parsed_doc.blocks if b.type.value == "admonition"]),
"word_count": parsed_doc.stats.total_blocks if parsed_doc.stats else 0,
"directives": len(
[b for b in parsed_doc.blocks if b.type.value == "admonition"]
),
"word_count": parsed_doc.stats.total_blocks
if parsed_doc.stats
else 0,
"line_count": len(content.split("\n")),
}
else:
@@ -752,7 +753,9 @@ def process_markdown_docs(
"tables": len(parsed_doc.tables),
"images": len(parsed_doc.images),
"links": len(parsed_doc.external_links),
"word_count": parsed_doc.stats.total_blocks if parsed_doc.stats else 0,
"word_count": parsed_doc.stats.total_blocks
if parsed_doc.stats
else 0,
"line_count": len(content.split("\n")),
}
except ImportError:
@@ -789,10 +792,15 @@ def process_markdown_docs(
"tables": len(parsed_doc.tables),
"cross_references": len(parsed_doc.internal_links),
"code_blocks": len(parsed_doc.code_blocks),
"images": len(getattr(parsed_doc, 'images', [])),
"images": len(getattr(parsed_doc, "images", [])),
"quality_scores": {
"avg_code_quality": sum(cb.quality_score or 0 for cb in parsed_doc.code_blocks) / len(parsed_doc.code_blocks) if parsed_doc.code_blocks else 0,
}
"avg_code_quality": sum(
cb.quality_score or 0 for cb in parsed_doc.code_blocks
)
/ len(parsed_doc.code_blocks)
if parsed_doc.code_blocks
else 0,
},
}
processed_docs.append(doc_data)
@@ -850,8 +858,12 @@ def process_markdown_docs(
enhanced_count = sum(1 for doc in processed_docs if doc.get("_enhanced", False))
if enhanced_count > 0:
total_tables = sum(doc.get("parsed_data", {}).get("tables", 0) for doc in processed_docs)
total_xrefs = sum(doc.get("parsed_data", {}).get("cross_references", 0) for doc in processed_docs)
total_code_blocks = sum(doc.get("parsed_data", {}).get("code_blocks", 0) for doc in processed_docs)
total_xrefs = sum(
doc.get("parsed_data", {}).get("cross_references", 0) for doc in processed_docs
)
total_code_blocks = sum(
doc.get("parsed_data", {}).get("code_blocks", 0) for doc in processed_docs
)
extraction_summary = {
"enhanced_files": enhanced_count,

View File

@@ -426,8 +426,7 @@ class DocToSkillConverter:
"url": url,
"title": doc.title or "",
"content": "\n\n".join(
p for p in doc._extract_content_text().split("\n\n")
if len(p.strip()) >= 20
p for p in doc._extract_content_text().split("\n\n") if len(p.strip()) >= 20
),
"headings": [
{"level": f"h{h.level}", "text": h.text, "id": h.id or ""}
@@ -2309,9 +2308,7 @@ def execute_enhancement(config: dict[str, Any], args: argparse.Namespace, conver
# Check if workflow was already executed (for logging context)
workflow_executed = (
converter
and hasattr(converter, 'workflow_executed')
and converter.workflow_executed
converter and hasattr(converter, "workflow_executed") and converter.workflow_executed
)
workflow_name = converter.workflow_name if workflow_executed else None
@@ -2328,7 +2325,9 @@ def execute_enhancement(config: dict[str, Any], args: argparse.Namespace, conver
logger.info("=" * 80)
if workflow_executed:
logger.info(f" Running after workflow: {workflow_name}")
logger.info(" (Workflow provides specialized analysis, enhancement provides general improvements)")
logger.info(
" (Workflow provides specialized analysis, enhancement provides general improvements)"
)
logger.info("")
try:

View File

@@ -197,9 +197,7 @@ class WorkflowEngine:
extends=data.get("extends"),
)
def _merge_workflows(
self, parent: EnhancementWorkflow, child_data: dict
) -> dict:
def _merge_workflows(self, parent: EnhancementWorkflow, child_data: dict) -> dict:
"""Merge child workflow with parent (inheritance)."""
# Start with parent as dict
merged = {
@@ -239,12 +237,8 @@ class WorkflowEngine:
parent_post = parent.post_process
child_post = child_data.get("post_process", {})
merged["post_process"] = {
"remove_sections": child_post.get(
"remove_sections", parent_post.remove_sections
),
"reorder_sections": child_post.get(
"reorder_sections", parent_post.reorder_sections
),
"remove_sections": child_post.get("remove_sections", parent_post.remove_sections),
"reorder_sections": child_post.get("reorder_sections", parent_post.reorder_sections),
"add_metadata": {
**parent_post.add_metadata,
**child_post.get("add_metadata", {}),
@@ -285,9 +279,7 @@ class WorkflowEngine:
logger.info(f"🔄 Running stage {idx}/{len(self.workflow.stages)}: {stage.name}")
# Build stage context
stage_context = self._build_stage_context(
stage, current_results, context
)
stage_context = self._build_stage_context(stage, current_results, context)
# Run stage
try:
@@ -408,9 +400,7 @@ class WorkflowEngine:
return result
def _merge_stage_results(
self, current: dict, stage_results: dict, target: str
) -> dict:
def _merge_stage_results(self, current: dict, stage_results: dict, target: str) -> dict:
"""Merge stage results into current results."""
if target == "all":
# Merge everything

View File

@@ -1454,7 +1454,9 @@ def main():
logger.info("=" * 80)
if workflow_executed:
logger.info(f" Running after workflow: {workflow_name}")
logger.info(" (Workflow provides specialized analysis, enhancement provides general improvements)")
logger.info(
" (Workflow provides specialized analysis, enhancement provides general improvements)"
)
logger.info("")
if api_key:
@@ -1491,7 +1493,9 @@ def main():
logger.info(f" skill-seekers enhance {skill_dir}/ --enhance-level 2")
logger.info(" (auto-detects API vs LOCAL mode based on ANTHROPIC_API_KEY)")
logger.info("\n💡 Or use a workflow:")
logger.info(f" skill-seekers github --repo {config['repo']} --enhance-workflow architecture-comprehensive")
logger.info(
f" skill-seekers github --repo {config['repo']} --enhance-workflow architecture-comprehensive"
)
logger.info(f"\nNext step: skill-seekers package {skill_dir}/")

View File

@@ -20,6 +20,7 @@ logger = logging.getLogger(__name__)
@dataclass
class ParseResult:
"""Result of parsing a document."""
document: Document | None = None
success: bool = False
errors: list[str] = field(default_factory=list)
@@ -56,11 +57,11 @@ class BaseParser(ABC):
- encoding: str = 'utf-8'
"""
self.options = options or {}
self._include_comments = self.options.get('include_comments', False)
self._extract_metadata = self.options.get('extract_metadata', True)
self._quality_scoring = self.options.get('quality_scoring', True)
self._max_file_size = self.options.get('max_file_size_mb', 50.0) * 1024 * 1024
self._encoding = self.options.get('encoding', 'utf-8')
self._include_comments = self.options.get("include_comments", False)
self._extract_metadata = self.options.get("extract_metadata", True)
self._quality_scoring = self.options.get("quality_scoring", True)
self._max_file_size = self.options.get("max_file_size_mb", 50.0) * 1024 * 1024
self._encoding = self.options.get("encoding", "utf-8")
@property
@abstractmethod
@@ -149,15 +150,19 @@ class BaseParser(ABC):
def parse_string(self, content: str, source_path: str = "<string>") -> ParseResult:
"""Parse content from string."""
# Create a wrapper that looks like a path
class StringSource:
def __init__(self, content: str, path: str):
self._content = content
self._path = path
def read_text(self, encoding: str = 'utf-8') -> str:
def read_text(self, encoding: str = "utf-8") -> str:
return self._content
def exists(self) -> bool:
return True
def __str__(self):
return self._path
@@ -238,17 +243,20 @@ class BaseParser(ABC):
document.stats.code_blocks = len(document.code_blocks)
document.stats.tables = len(document.tables)
document.stats.headings = len(document.headings)
document.stats.cross_references = len(document.internal_links) + len(document.external_links)
document.stats.cross_references = len(document.internal_links) + len(
document.external_links
)
return document
def _extract_headings(self, document: Document) -> list:
"""Extract headings from content blocks."""
from .unified_structure import ContentBlockType
headings = []
for block in document.blocks:
if block.type == ContentBlockType.HEADING:
heading_data = block.metadata.get('heading_data')
heading_data = block.metadata.get("heading_data")
if heading_data:
headings.append(heading_data)
return headings
@@ -257,22 +265,23 @@ class BaseParser(ABC):
"""Extract code blocks from content blocks."""
code_blocks = []
for block in document.blocks:
if block.metadata.get('code_data'):
code_blocks.append(block.metadata['code_data'])
if block.metadata.get("code_data"):
code_blocks.append(block.metadata["code_data"])
return code_blocks
def _extract_tables(self, document: Document) -> list:
"""Extract tables from content blocks."""
tables = []
for block in document.blocks:
if block.metadata.get('table_data'):
tables.append(block.metadata['table_data'])
if block.metadata.get("table_data"):
tables.append(block.metadata["table_data"])
return tables
def _create_quality_scorer(self):
"""Create a quality scorer if enabled."""
if self._quality_scoring:
from .quality_scorer import QualityScorer
return QualityScorer()
return None
@@ -292,12 +301,14 @@ def get_parser_for_file(path: str | Path) -> BaseParser | None:
# Try RST parser
from .rst_parser import RstParser
rst_parser = RstParser()
if suffix in rst_parser.supported_extensions:
return rst_parser
# Try Markdown parser
from .markdown_parser import MarkdownParser
md_parser = MarkdownParser()
if suffix in md_parser.supported_extensions:
return md_parser
@@ -320,11 +331,13 @@ def parse_document(source: str | Path, format_hint: str | None = None) -> ParseR
"""
# Use format hint if provided
if format_hint:
if format_hint.lower() in ('rst', 'rest', 'restructuredtext'):
if format_hint.lower() in ("rst", "rest", "restructuredtext"):
from .rst_parser import RstParser
return RstParser().parse(source)
elif format_hint.lower() in ('md', 'markdown'):
elif format_hint.lower() in ("md", "markdown"):
from .markdown_parser import MarkdownParser
return MarkdownParser().parse(source)
# Auto-detect from file extension
@@ -336,11 +349,13 @@ def parse_document(source: str | Path, format_hint: str | None = None) -> ParseR
content = source if isinstance(source, str) else Path(source).read_text()
# Check for RST indicators
rst_indicators = ['.. ', '::\n', ':ref:`', '.. toctree::', '.. code-block::']
rst_indicators = [".. ", "::\n", ":ref:`", ".. toctree::", ".. code-block::"]
if any(ind in content for ind in rst_indicators):
from .rst_parser import RstParser
return RstParser().parse_string(content)
# Default to Markdown
from .markdown_parser import MarkdownParser
return MarkdownParser().parse_string(content)

View File

@@ -7,7 +7,12 @@ Convert unified Document structure to various output formats.
from typing import Any
from .unified_structure import (
Document, ContentBlock, ContentBlockType, AdmonitionType, ListType, Table
Document,
ContentBlock,
ContentBlockType,
AdmonitionType,
ListType,
Table,
)
@@ -16,10 +21,10 @@ class MarkdownFormatter:
def __init__(self, options: dict[str, Any] = None):
self.options = options or {}
self.include_toc = self.options.get('include_toc', False)
self.max_heading_level = self.options.get('max_heading_level', 6)
self.code_block_style = self.options.get('code_block_style', 'fenced')
self.table_style = self.options.get('table_style', 'github')
self.include_toc = self.options.get("include_toc", False)
self.max_heading_level = self.options.get("max_heading_level", 6)
self.code_block_style = self.options.get("code_block_style", "fenced")
self.table_style = self.options.get("table_style", "github")
def format(self, document: Document) -> str:
"""Convert document to markdown string."""
@@ -43,11 +48,11 @@ class MarkdownFormatter:
if formatted:
parts.append(formatted)
return '\n'.join(parts)
return "\n".join(parts)
def _format_metadata(self, meta: dict) -> str:
"""Format metadata as YAML frontmatter."""
lines = ['---']
lines = ["---"]
for key, value in meta.items():
if isinstance(value, list):
lines.append(f"{key}:")
@@ -55,19 +60,19 @@ class MarkdownFormatter:
lines.append(f" - {item}")
else:
lines.append(f"{key}: {value}")
lines.append('---\n')
return '\n'.join(lines)
lines.append("---\n")
return "\n".join(lines)
def _format_toc(self, headings: list) -> str:
"""Format table of contents."""
lines = ['## Table of Contents\n']
lines = ["## Table of Contents\n"]
for h in headings:
if h.level <= self.max_heading_level:
indent = ' ' * (h.level - 1)
anchor = h.id or h.text.lower().replace(' ', '-')
indent = " " * (h.level - 1)
anchor = h.id or h.text.lower().replace(" ", "-")
lines.append(f"{indent}- [{h.text}](#{anchor})")
lines.append('')
return '\n'.join(lines)
lines.append("")
return "\n".join(lines)
def _format_block(self, block: ContentBlock) -> str:
"""Format a single content block."""
@@ -91,16 +96,16 @@ class MarkdownFormatter:
return handler(block)
# Default: return content as-is
return block.content + '\n'
return block.content + "\n"
def _format_heading(self, block: ContentBlock) -> str:
"""Format heading block."""
heading_data = block.metadata.get('heading_data')
heading_data = block.metadata.get("heading_data")
if heading_data:
level = min(heading_data.level, 6)
text = heading_data.text
else:
level = block.metadata.get('level', 1)
level = block.metadata.get("level", 1)
text = block.content
if level > self.max_heading_level:
@@ -110,38 +115,38 @@ class MarkdownFormatter:
def _format_paragraph(self, block: ContentBlock) -> str:
"""Format paragraph block."""
return block.content + '\n'
return block.content + "\n"
def _format_code_block(self, block: ContentBlock) -> str:
"""Format code block."""
code_data = block.metadata.get('code_data')
code_data = block.metadata.get("code_data")
if code_data:
code = code_data.code
lang = code_data.language or ''
lang = code_data.language or ""
else:
code = block.content
lang = block.metadata.get('language', '')
lang = block.metadata.get("language", "")
if self.code_block_style == 'fenced':
if self.code_block_style == "fenced":
return f"```{lang}\n{code}\n```\n"
else:
# Indented style
indented = '\n'.join(' ' + line for line in code.split('\n'))
return indented + '\n'
indented = "\n".join(" " + line for line in code.split("\n"))
return indented + "\n"
def _format_table(self, block: ContentBlock) -> str:
"""Format table block."""
table_data = block.metadata.get('table_data')
table_data = block.metadata.get("table_data")
if not table_data:
return ''
return ""
return self._format_table_data(table_data)
def _format_table_data(self, table: Table) -> str:
"""Format table data as markdown."""
if not table.rows:
return ''
return ""
lines = []
@@ -151,92 +156,92 @@ class MarkdownFormatter:
# Headers
headers = table.headers or table.rows[0]
lines.append('| ' + ' | '.join(headers) + ' |')
lines.append('|' + '|'.join('---' for _ in headers) + '|')
lines.append("| " + " | ".join(headers) + " |")
lines.append("|" + "|".join("---" for _ in headers) + "|")
# Rows (skip first if used as headers)
start_row = 0 if table.headers else 1
for row in table.rows[start_row:]:
# Pad row to match header count
padded_row = row + [''] * (len(headers) - len(row))
lines.append('| ' + ' | '.join(padded_row[:len(headers)]) + ' |')
padded_row = row + [""] * (len(headers) - len(row))
lines.append("| " + " | ".join(padded_row[: len(headers)]) + " |")
lines.append('')
return '\n'.join(lines)
lines.append("")
return "\n".join(lines)
def _format_list(self, block: ContentBlock) -> str:
"""Format list block."""
list_type = block.metadata.get('list_type', ListType.BULLET)
items = block.metadata.get('items', [])
list_type = block.metadata.get("list_type", ListType.BULLET)
items = block.metadata.get("items", [])
if not items:
return block.content + '\n'
return block.content + "\n"
lines = []
for i, item in enumerate(items):
prefix = f"{i + 1}." if list_type == ListType.NUMBERED else "-"
lines.append(f"{prefix} {item}")
lines.append('')
return '\n'.join(lines)
lines.append("")
return "\n".join(lines)
def _format_image(self, block: ContentBlock) -> str:
"""Format image block."""
image_data = block.metadata.get('image_data')
image_data = block.metadata.get("image_data")
if image_data:
src = image_data.source
alt = image_data.alt_text or ''
alt = image_data.alt_text or ""
else:
src = block.metadata.get('src', '')
alt = block.metadata.get('alt', '')
src = block.metadata.get("src", "")
alt = block.metadata.get("alt", "")
return f"![{alt}]({src})\n"
def _format_cross_ref(self, block: ContentBlock) -> str:
"""Format cross-reference block."""
xref_data = block.metadata.get('xref_data')
xref_data = block.metadata.get("xref_data")
if xref_data:
text = xref_data.text or xref_data.target
target = xref_data.target
return f"[{text}](#{target})\n"
return block.content + '\n'
return block.content + "\n"
def _format_admonition(self, block: ContentBlock) -> str:
"""Format admonition/callout block."""
admonition_type = block.metadata.get('admonition_type', AdmonitionType.NOTE)
admonition_type = block.metadata.get("admonition_type", AdmonitionType.NOTE)
# GitHub-style admonitions
type_map = {
AdmonitionType.NOTE: 'NOTE',
AdmonitionType.WARNING: 'WARNING',
AdmonitionType.TIP: 'TIP',
AdmonitionType.IMPORTANT: 'IMPORTANT',
AdmonitionType.CAUTION: 'CAUTION',
AdmonitionType.NOTE: "NOTE",
AdmonitionType.WARNING: "WARNING",
AdmonitionType.TIP: "TIP",
AdmonitionType.IMPORTANT: "IMPORTANT",
AdmonitionType.CAUTION: "CAUTION",
}
type_str = type_map.get(admonition_type, 'NOTE')
type_str = type_map.get(admonition_type, "NOTE")
content = block.content
return f"> [!{type_str}]\n> {content.replace(chr(10), chr(10) + '> ')}\n"
def _format_directive(self, block: ContentBlock) -> str:
"""Format directive block (RST-specific)."""
directive_name = block.metadata.get('directive_name', 'unknown')
directive_name = block.metadata.get("directive_name", "unknown")
# Format as a blockquote with directive name
content = block.content
lines = [f"> **{directive_name}**"]
for line in content.split('\n'):
for line in content.split("\n"):
lines.append(f"> {line}")
lines.append('')
return '\n'.join(lines)
lines.append("")
return "\n".join(lines)
def _format_field_list(self, block: ContentBlock) -> str:
"""Format field list block."""
fields = block.metadata.get('fields', [])
fields = block.metadata.get("fields", [])
if not fields:
return block.content + '\n'
return block.content + "\n"
lines = []
for field in fields:
@@ -244,14 +249,14 @@ class MarkdownFormatter:
lines.append(f"**{field.name}** (`{field.arg}`): {field.content}")
else:
lines.append(f"**{field.name}**: {field.content}")
lines.append('')
return '\n'.join(lines)
lines.append("")
return "\n".join(lines)
def _format_definition_list(self, block: ContentBlock) -> str:
"""Format definition list block."""
items = block.metadata.get('items', [])
items = block.metadata.get("items", [])
if not items:
return block.content + '\n'
return block.content + "\n"
lines = []
for item in items:
@@ -260,12 +265,12 @@ class MarkdownFormatter:
else:
lines.append(f"**{item.term}**")
lines.append(f": {item.definition}")
lines.append('')
return '\n'.join(lines)
lines.append("")
return "\n".join(lines)
def _format_meta(self, block: ContentBlock) -> str:
"""Format metadata block (usually filtered out)."""
return '' # Metadata goes in YAML frontmatter
return "" # Metadata goes in YAML frontmatter
class SkillFormatter:
@@ -278,10 +283,7 @@ class SkillFormatter:
"source_path": document.source_path,
"format": document.format,
"content_summary": self._extract_summary(document),
"headings": [
{"level": h.level, "text": h.text, "id": h.id}
for h in document.headings
],
"headings": [{"level": h.level, "text": h.text, "id": h.id} for h in document.headings],
"code_samples": [
{
"code": cb.code,
@@ -318,7 +320,7 @@ class SkillFormatter:
"headings": document.stats.headings,
"cross_references": document.stats.cross_references,
"processing_time_ms": document.stats.processing_time_ms,
}
},
}
def _extract_summary(self, document: Document, max_length: int = 500) -> str:
@@ -327,12 +329,12 @@ class SkillFormatter:
for block in document.blocks:
if block.type == ContentBlockType.PARAGRAPH:
paragraphs.append(block.content)
if len(' '.join(paragraphs)) > max_length:
if len(" ".join(paragraphs)) > max_length:
break
summary = ' '.join(paragraphs)
summary = " ".join(paragraphs)
if len(summary) > max_length:
summary = summary[:max_length - 3] + '...'
summary = summary[: max_length - 3] + "..."
return summary

View File

@@ -21,8 +21,17 @@ from typing import Any
from .base_parser import BaseParser
from .unified_structure import (
Document, ContentBlock, ContentBlockType, CrossReference, CrossRefType,
AdmonitionType, Heading, CodeBlock, Table, Image, ListType
Document,
ContentBlock,
ContentBlockType,
CrossReference,
CrossRefType,
AdmonitionType,
Heading,
CodeBlock,
Table,
Image,
ListType,
)
from .quality_scorer import QualityScorer
@@ -36,14 +45,14 @@ class MarkdownParser(BaseParser):
# Admonition types for GitHub-style callouts
ADMONITION_TYPES = {
'note': AdmonitionType.NOTE,
'warning': AdmonitionType.WARNING,
'tip': AdmonitionType.TIP,
'hint': AdmonitionType.HINT,
'important': AdmonitionType.IMPORTANT,
'caution': AdmonitionType.CAUTION,
'danger': AdmonitionType.DANGER,
'attention': AdmonitionType.ATTENTION,
"note": AdmonitionType.NOTE,
"warning": AdmonitionType.WARNING,
"tip": AdmonitionType.TIP,
"hint": AdmonitionType.HINT,
"important": AdmonitionType.IMPORTANT,
"caution": AdmonitionType.CAUTION,
"danger": AdmonitionType.DANGER,
"attention": AdmonitionType.ATTENTION,
}
def __init__(self, options: dict[str, Any] | None = None):
@@ -54,32 +63,32 @@ class MarkdownParser(BaseParser):
@property
def format_name(self) -> str:
return 'markdown'
return "markdown"
@property
def supported_extensions(self) -> list[str]:
return ['.md', '.markdown', '.mdown', '.mkd']
return [".md", ".markdown", ".mdown", ".mkd"]
def _detect_format(self, content: str) -> bool:
"""Detect if content is Markdown."""
md_indicators = [
r'^#{1,6}\s+\S', # ATX headers
r'^\[.*?\]\(.*?\)', # Links
r'^```', # Code fences
r'^\|.+\|', # Tables
r'^\s*[-*+]\s+\S', # Lists
r'^>\s+\S', # Blockquotes
r"^#{1,6}\s+\S", # ATX headers
r"^\[.*?\]\(.*?\)", # Links
r"^```", # Code fences
r"^\|.+\|", # Tables
r"^\s*[-*+]\s+\S", # Lists
r"^>\s+\S", # Blockquotes
]
return any(re.search(pattern, content, re.MULTILINE) for pattern in md_indicators)
def _parse_content(self, content: str, source_path: str) -> Document:
"""Parse Markdown content into Document."""
self._lines = content.split('\n')
self._lines = content.split("\n")
self._current_line = 0
document = Document(
title='',
format='markdown',
title="",
format="markdown",
source_path=source_path,
)
@@ -96,12 +105,12 @@ class MarkdownParser(BaseParser):
self._current_line += 1
# Extract title from first h1 or frontmatter
if document.meta.get('title'):
document.title = document.meta['title']
if document.meta.get("title"):
document.title = document.meta["title"]
else:
for block in document.blocks:
if block.type == ContentBlockType.HEADING:
heading_data = block.metadata.get('heading_data')
heading_data = block.metadata.get("heading_data")
if heading_data and heading_data.level == 1:
document.title = heading_data.text
break
@@ -117,13 +126,13 @@ class MarkdownParser(BaseParser):
return None
first_line = self._lines[self._current_line].strip()
if first_line != '---':
if first_line != "---":
return None
# Find closing ---
end_line = None
for i in range(self._current_line + 1, len(self._lines)):
if self._lines[i].strip() == '---':
if self._lines[i].strip() == "---":
end_line = i
break
@@ -132,7 +141,7 @@ class MarkdownParser(BaseParser):
# Extract frontmatter content
frontmatter_lines = self._lines[self._current_line + 1 : end_line]
'\n'.join(frontmatter_lines)
"\n".join(frontmatter_lines)
# Simple key: value parsing (not full YAML)
meta = {}
@@ -145,11 +154,11 @@ class MarkdownParser(BaseParser):
continue
# Check for new key
match = re.match(r'^(\w+):\s*(.*)$', stripped)
match = re.match(r"^(\w+):\s*(.*)$", stripped)
if match:
# Save previous key
if current_key:
meta[current_key] = '\n'.join(current_value).strip()
meta[current_key] = "\n".join(current_value).strip()
current_key = match.group(1)
value = match.group(2)
@@ -157,27 +166,27 @@ class MarkdownParser(BaseParser):
# Handle inline value
if value:
# Check if it's a list
if value.startswith('[') and value.endswith(']'):
if value.startswith("[") and value.endswith("]"):
# Parse list
items = [item.strip().strip('"\'') for item in value[1:-1].split(',')]
items = [item.strip().strip("\"'") for item in value[1:-1].split(",")]
meta[current_key] = items
else:
current_value = [value]
else:
current_value = []
elif current_key and stripped.startswith('- '):
elif current_key and stripped.startswith("- "):
# List item
if current_key not in meta:
meta[current_key] = []
if not isinstance(meta[current_key], list):
meta[current_key] = [meta[current_key]]
meta[current_key].append(stripped[2:].strip().strip('"\''))
meta[current_key].append(stripped[2:].strip().strip("\"'"))
elif current_key:
current_value.append(stripped)
# Save last key
if current_key:
meta[current_key] = '\n'.join(current_value).strip()
meta[current_key] = "\n".join(current_value).strip()
# Advance past frontmatter
self._current_line = end_line + 1
@@ -198,11 +207,11 @@ class MarkdownParser(BaseParser):
return None
# Skip HTML comments
if stripped.startswith('<!--'):
if stripped.startswith("<!--"):
return self._parse_html_comment()
# ATX Headers
if stripped.startswith('#'):
if stripped.startswith("#"):
return self._parse_atx_header()
# Setext headers (underline style)
@@ -210,23 +219,23 @@ class MarkdownParser(BaseParser):
return self._parse_setext_header()
# Code fence
if stripped.startswith('```'):
if stripped.startswith("```"):
return self._parse_code_fence()
# Indented code block
if current.startswith(' ') or current.startswith('\t'):
if current.startswith(" ") or current.startswith("\t"):
return self._parse_indented_code()
# Table
if '|' in stripped and self._is_table(line):
if "|" in stripped and self._is_table(line):
return self._parse_table()
# Blockquote (check for admonition)
if stripped.startswith('>'):
if stripped.startswith(">"):
return self._parse_blockquote()
# Horizontal rule
if re.match(r'^[\-*_]{3,}\s*$', stripped):
if re.match(r"^[\-*_]{3,}\s*$", stripped):
return self._parse_horizontal_rule()
# List
@@ -249,18 +258,18 @@ class MarkdownParser(BaseParser):
return False
# H1: ===, H2: ---
return re.match(r'^[=-]+$', next_line) is not None
return re.match(r"^[=-]+$", next_line) is not None
def _parse_atx_header(self) -> ContentBlock:
"""Parse ATX style header (# Header)."""
line = self._lines[self._current_line]
match = re.match(r'^(#{1,6})\s+(.+)$', line.strip())
match = re.match(r"^(#{1,6})\s+(.+)$", line.strip())
if match:
level = len(match.group(1))
text = match.group(2).strip()
# Remove trailing hashes
text = re.sub(r'\s+#+$', '', text)
text = re.sub(r"\s+#+$", "", text)
anchor = self._create_anchor(text)
@@ -274,7 +283,7 @@ class MarkdownParser(BaseParser):
return ContentBlock(
type=ContentBlockType.HEADING,
content=text,
metadata={'heading_data': heading},
metadata={"heading_data": heading},
source_line=self._current_line + 1,
)
@@ -285,7 +294,7 @@ class MarkdownParser(BaseParser):
text = self._lines[self._current_line].strip()
underline = self._lines[self._current_line + 1].strip()
level = 1 if underline[0] == '=' else 2
level = 1 if underline[0] == "=" else 2
anchor = self._create_anchor(text)
heading = Heading(
@@ -301,14 +310,14 @@ class MarkdownParser(BaseParser):
return ContentBlock(
type=ContentBlockType.HEADING,
content=text,
metadata={'heading_data': heading},
metadata={"heading_data": heading},
source_line=self._current_line,
)
def _parse_code_fence(self) -> ContentBlock:
"""Parse fenced code block."""
line = self._lines[self._current_line]
match = re.match(r'^```(\w+)?\s*$', line.strip())
match = re.match(r"^```(\w+)?\s*$", line.strip())
language = match.group(1) if match else None
start_line = self._current_line
@@ -317,19 +326,19 @@ class MarkdownParser(BaseParser):
code_lines = []
while self._current_line < len(self._lines):
current_line = self._lines[self._current_line]
if current_line.strip() == '```':
if current_line.strip() == "```":
break
code_lines.append(current_line)
self._current_line += 1
code = '\n'.join(code_lines)
code = "\n".join(code_lines)
# Detect language if not specified
detected_lang, confidence = self.quality_scorer.detect_language(code)
if not language and confidence > 0.6:
language = detected_lang
elif not language:
language = 'text'
language = "text"
# Score code quality
quality = self.quality_scorer.score_code_block(code, language)
@@ -346,8 +355,8 @@ class MarkdownParser(BaseParser):
type=ContentBlockType.CODE_BLOCK,
content=code,
metadata={
'code_data': code_block,
'language': language,
"code_data": code_block,
"language": language,
},
source_line=start_line + 1,
quality_score=quality,
@@ -361,13 +370,13 @@ class MarkdownParser(BaseParser):
while self._current_line < len(self._lines):
line = self._lines[self._current_line]
if not line.strip():
code_lines.append('')
code_lines.append("")
self._current_line += 1
continue
if line.startswith(' '):
if line.startswith(" "):
code_lines.append(line[4:])
elif line.startswith('\t'):
elif line.startswith("\t"):
code_lines.append(line[1:])
else:
self._current_line -= 1
@@ -375,7 +384,7 @@ class MarkdownParser(BaseParser):
self._current_line += 1
code = '\n'.join(code_lines).rstrip()
code = "\n".join(code_lines).rstrip()
# Detect language
detected_lang, confidence = self.quality_scorer.detect_language(code)
@@ -383,7 +392,7 @@ class MarkdownParser(BaseParser):
code_block = CodeBlock(
code=code,
language=detected_lang if confidence > 0.6 else 'text',
language=detected_lang if confidence > 0.6 else "text",
quality_score=quality,
confidence=confidence,
source_line=start_line + 1,
@@ -393,8 +402,8 @@ class MarkdownParser(BaseParser):
type=ContentBlockType.CODE_BLOCK,
content=code,
metadata={
'code_data': code_block,
'language': detected_lang,
"code_data": code_block,
"language": detected_lang,
},
source_line=start_line + 1,
quality_score=quality,
@@ -409,7 +418,7 @@ class MarkdownParser(BaseParser):
next_line = self._lines[line + 1].strip()
# Check for table separator line
return bool(re.match(r'^[\|:-]+$', next_line) and '|' in current)
return bool(re.match(r"^[\|:-]+$", next_line) and "|" in current)
def _parse_table(self) -> ContentBlock:
"""Parse a GFM table."""
@@ -419,7 +428,7 @@ class MarkdownParser(BaseParser):
# Parse header row
header_line = self._lines[self._current_line].strip()
headers = [cell.strip() for cell in header_line.split('|')]
headers = [cell.strip() for cell in header_line.split("|")]
headers = [h for h in headers if h] # Remove empty
self._current_line += 1
@@ -431,11 +440,11 @@ class MarkdownParser(BaseParser):
while self._current_line < len(self._lines):
line = self._lines[self._current_line].strip()
if not line or '|' not in line:
if not line or "|" not in line:
self._current_line -= 1
break
cells = [cell.strip() for cell in line.split('|')]
cells = [cell.strip() for cell in line.split("|")]
cells = [c for c in cells if c]
if cells:
rows.append(cells)
@@ -446,7 +455,7 @@ class MarkdownParser(BaseParser):
rows=rows,
headers=headers,
caption=None,
source_format='markdown',
source_format="markdown",
source_line=start_line + 1,
)
@@ -455,7 +464,7 @@ class MarkdownParser(BaseParser):
return ContentBlock(
type=ContentBlockType.TABLE,
content=f"[Table: {len(rows)} rows]",
metadata={'table_data': table},
metadata={"table_data": table},
source_line=start_line + 1,
quality_score=quality,
)
@@ -471,15 +480,15 @@ class MarkdownParser(BaseParser):
line = self._lines[self._current_line]
stripped = line.strip()
if not stripped.startswith('>'):
if not stripped.startswith(">"):
self._current_line -= 1
break
# Remove > prefix
content = line[1:].strip() if line.startswith('> ') else line[1:].strip()
content = line[1:].strip() if line.startswith("> ") else line[1:].strip()
# Check for GitHub-style admonition: > [!NOTE]
admonition_match = re.match(r'^\[!([\w]+)\]\s*(.*)$', content)
admonition_match = re.match(r"^\[!([\w]+)\]\s*(.*)$", content)
if admonition_match and not admonition_type:
type_name = admonition_match.group(1).lower()
admonition_type = self.ADMONITION_TYPES.get(type_name)
@@ -497,17 +506,17 @@ class MarkdownParser(BaseParser):
if admonition_type:
return ContentBlock(
type=ContentBlockType.ADMONITION,
content='\n'.join(admonition_content),
metadata={'admonition_type': admonition_type},
content="\n".join(admonition_content),
metadata={"admonition_type": admonition_type},
source_line=start_line + 1,
)
# Regular blockquote
content = '\n'.join(lines)
content = "\n".join(lines)
return ContentBlock(
type=ContentBlockType.RAW,
content=f"> {content}",
metadata={'block_type': 'blockquote'},
metadata={"block_type": "blockquote"},
source_line=start_line + 1,
)
@@ -519,7 +528,7 @@ class MarkdownParser(BaseParser):
line = self._lines[self._current_line]
content_lines.append(line)
if '-->' in line:
if "-->" in line:
break
self._current_line += 1
@@ -531,16 +540,16 @@ class MarkdownParser(BaseParser):
"""Parse horizontal rule."""
return ContentBlock(
type=ContentBlockType.RAW,
content='---',
metadata={'element': 'horizontal_rule'},
content="---",
metadata={"element": "horizontal_rule"},
source_line=self._current_line + 1,
)
def _detect_list_type(self, stripped: str) -> ListType | None:
"""Detect if line starts a list and which type."""
if re.match(r'^[-*+]\s+', stripped):
if re.match(r"^[-*+]\s+", stripped):
return ListType.BULLET
if re.match(r'^\d+\.\s+', stripped):
if re.match(r"^\d+\.\s+", stripped):
return ListType.NUMBERED
return None
@@ -559,13 +568,13 @@ class MarkdownParser(BaseParser):
# Check if still in list
if list_type == ListType.BULLET:
match = re.match(r'^[-*+]\s+(.+)$', stripped)
match = re.match(r"^[-*+]\s+(.+)$", stripped)
if not match:
self._current_line -= 1
break
items.append(match.group(1))
else: # NUMBERED
match = re.match(r'^\d+\.\s+(.+)$', stripped)
match = re.match(r"^\d+\.\s+(.+)$", stripped)
if not match:
self._current_line -= 1
break
@@ -577,8 +586,8 @@ class MarkdownParser(BaseParser):
type=ContentBlockType.LIST,
content=f"{len(items)} items",
metadata={
'list_type': list_type,
'items': items,
"list_type": list_type,
"items": items,
},
source_line=start_line + 1,
)
@@ -597,15 +606,15 @@ class MarkdownParser(BaseParser):
break
# Check for block-level elements
if stripped.startswith('#'):
if stripped.startswith("#"):
break
if stripped.startswith('```'):
if stripped.startswith("```"):
break
if stripped.startswith('>'):
if stripped.startswith(">"):
break
if stripped.startswith('---') or stripped.startswith('***'):
if stripped.startswith("---") or stripped.startswith("***"):
break
if stripped.startswith('|') and self._is_table(self._current_line):
if stripped.startswith("|") and self._is_table(self._current_line):
break
if self._detect_list_type(stripped):
break
@@ -615,7 +624,7 @@ class MarkdownParser(BaseParser):
lines.append(stripped)
self._current_line += 1
content = ' '.join(lines)
content = " ".join(lines)
# Process inline elements
content = self._process_inline(content)
@@ -629,60 +638,60 @@ class MarkdownParser(BaseParser):
def _process_inline(self, text: str) -> str:
"""Process inline Markdown elements."""
# Links [text](url)
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'[\1](\2)', text)
text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"[\1](\2)", text)
# Images ![alt](url)
text = re.sub(r'!\[([^\]]*)\]\(([^)]+)\)', r'![\1](\2)', text)
text = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r"![\1](\2)", text)
# Code `code`
text = re.sub(r'`([^`]+)`', r'`\1`', text)
text = re.sub(r"`([^`]+)`", r"`\1`", text)
# Bold **text** or __text__
text = re.sub(r'\*\*([^*]+)\*\*', r'**\1**', text)
text = re.sub(r'__([^_]+)__', r'**\1**', text)
text = re.sub(r"\*\*([^*]+)\*\*", r"**\1**", text)
text = re.sub(r"__([^_]+)__", r"**\1**", text)
# Italic *text* or _text_
text = re.sub(r'(?<!\*)\*([^*]+)\*(?!\*)', r'*\1*', text)
text = re.sub(r'(?<!_)_([^_]+)_(?!_)', r'*\1*', text)
text = re.sub(r"(?<!\*)\*([^*]+)\*(?!\*)", r"*\1*", text)
text = re.sub(r"(?<!_)_([^_]+)_(?!_)", r"*\1*", text)
# Strikethrough ~~text~~
text = re.sub(r'~~([^~]+)~~', r'~~\1~~', text)
text = re.sub(r"~~([^~]+)~~", r"~~\1~~", text)
return text
def _create_anchor(self, text: str) -> str:
"""Create URL anchor from heading text."""
anchor = text.lower()
anchor = re.sub(r'[^\w\s-]', '', anchor)
anchor = anchor.replace(' ', '-')
anchor = re.sub(r'-+', '-', anchor)
return anchor.strip('-')
anchor = re.sub(r"[^\w\s-]", "", anchor)
anchor = anchor.replace(" ", "-")
anchor = re.sub(r"-+", "-", anchor)
return anchor.strip("-")
def _extract_specialized_content(self, document: Document):
"""Extract specialized content lists from blocks."""
for block in document.blocks:
# Extract headings
if block.type == ContentBlockType.HEADING:
heading_data = block.metadata.get('heading_data')
heading_data = block.metadata.get("heading_data")
if heading_data:
document.headings.append(heading_data)
# Extract code blocks
elif block.type == ContentBlockType.CODE_BLOCK:
code_data = block.metadata.get('code_data')
code_data = block.metadata.get("code_data")
if code_data:
document.code_blocks.append(code_data)
# Extract tables
elif block.type == ContentBlockType.TABLE:
table_data = block.metadata.get('table_data')
table_data = block.metadata.get("table_data")
if table_data:
document.tables.append(table_data)
# Extract images from paragraphs (simplified)
elif block.type == ContentBlockType.PARAGRAPH:
content = block.content
img_matches = re.findall(r'!\[([^\]]*)\]\(([^)]+)\)', content)
img_matches = re.findall(r"!\[([^\]]*)\]\(([^)]+)\)", content)
for alt, src in img_matches:
image = Image(
source=src,
@@ -692,12 +701,12 @@ class MarkdownParser(BaseParser):
document.images.append(image)
# Extract links
link_matches = re.findall(r'\[([^\]]+)\]\(([^)]+)\)', content)
link_matches = re.findall(r"\[([^\]]+)\]\(([^)]+)\)", content)
for text, url in link_matches:
# Determine if internal or external
if url.startswith('#'):
if url.startswith("#"):
ref_type = CrossRefType.INTERNAL
elif url.startswith('http'):
elif url.startswith("http"):
ref_type = CrossRefType.EXTERNAL
else:
ref_type = CrossRefType.INTERNAL

View File

@@ -25,6 +25,7 @@ try:
except ImportError:
# Fallback for relative import
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from pdf_extractor_poc import PDFExtractor
@@ -75,9 +76,7 @@ class PdfParser(BaseParser):
This method is mainly for API compatibility.
"""
# For PDF, we need to use parse_file
raise NotImplementedError(
"PDF parsing requires file path. Use parse_file() instead."
)
raise NotImplementedError("PDF parsing requires file path. Use parse_file() instead.")
def parse_file(self, path: str | Path) -> ParseResult:
"""

View File

@@ -17,107 +17,133 @@ class QualityScorer:
# Language patterns for detection and validation
LANGUAGE_PATTERNS = {
'python': {
'keywords': ['def ', 'class ', 'import ', 'from ', 'return ', 'if ', 'for ', 'while'],
'syntax_checks': [
(r':\s*$', 'colon_ending'), # Python uses colons for blocks
(r'def\s+\w+\s*\([^)]*\)\s*:', 'function_def'),
(r'class\s+\w+', 'class_def'),
"python": {
"keywords": ["def ", "class ", "import ", "from ", "return ", "if ", "for ", "while"],
"syntax_checks": [
(r":\s*$", "colon_ending"), # Python uses colons for blocks
(r"def\s+\w+\s*\([^)]*\)\s*:", "function_def"),
(r"class\s+\w+", "class_def"),
],
},
'javascript': {
'keywords': ['function', 'const ', 'let ', 'var ', '=>', 'return ', 'if(', 'for('],
'syntax_checks': [
(r'function\s+\w+\s*\(', 'function_def'),
(r'const\s+\w+\s*=', 'const_decl'),
(r'=>', 'arrow_function'),
"javascript": {
"keywords": ["function", "const ", "let ", "var ", "=>", "return ", "if(", "for("],
"syntax_checks": [
(r"function\s+\w+\s*\(", "function_def"),
(r"const\s+\w+\s*=", "const_decl"),
(r"=>", "arrow_function"),
],
},
'typescript': {
'keywords': ['interface ', 'type ', ': string', ': number', ': boolean', 'implements'],
'syntax_checks': [
(r'interface\s+\w+', 'interface_def'),
(r':\s*(string|number|boolean|any)', 'type_annotation'),
"typescript": {
"keywords": ["interface ", "type ", ": string", ": number", ": boolean", "implements"],
"syntax_checks": [
(r"interface\s+\w+", "interface_def"),
(r":\s*(string|number|boolean|any)", "type_annotation"),
],
},
'java': {
'keywords': ['public ', 'private ', 'class ', 'void ', 'String ', 'int ', 'return '],
'syntax_checks': [
(r'public\s+class\s+\w+', 'class_def'),
(r'public\s+\w+\s+\w+\s*\(', 'method_def'),
"java": {
"keywords": ["public ", "private ", "class ", "void ", "String ", "int ", "return "],
"syntax_checks": [
(r"public\s+class\s+\w+", "class_def"),
(r"public\s+\w+\s+\w+\s*\(", "method_def"),
],
},
'cpp': {
'keywords': ['#include', 'using namespace', 'std::', 'cout', 'cin', 'public:', 'private:'],
'syntax_checks': [
(r'#include\s*[<"]', 'include'),
(r'std::', 'std_namespace'),
"cpp": {
"keywords": [
"#include",
"using namespace",
"std::",
"cout",
"cin",
"public:",
"private:",
],
"syntax_checks": [
(r'#include\s*[<"]', "include"),
(r"std::", "std_namespace"),
],
},
'csharp': {
'keywords': ['namespace ', 'public class', 'private ', 'void ', 'string ', 'int '],
'syntax_checks': [
(r'namespace\s+\w+', 'namespace'),
(r'public\s+class\s+\w+', 'class_def'),
"csharp": {
"keywords": ["namespace ", "public class", "private ", "void ", "string ", "int "],
"syntax_checks": [
(r"namespace\s+\w+", "namespace"),
(r"public\s+class\s+\w+", "class_def"),
],
},
'go': {
'keywords': ['package ', 'func ', 'import ', 'return ', 'if ', 'for ', 'range '],
'syntax_checks': [
(r'func\s+\w+\s*\(', 'function_def'),
(r'package\s+\w+', 'package_decl'),
"go": {
"keywords": ["package ", "func ", "import ", "return ", "if ", "for ", "range "],
"syntax_checks": [
(r"func\s+\w+\s*\(", "function_def"),
(r"package\s+\w+", "package_decl"),
],
},
'rust': {
'keywords': ['fn ', 'let ', 'mut ', 'impl ', 'struct ', 'enum ', 'match ', 'use '],
'syntax_checks': [
(r'fn\s+\w+\s*\(', 'function_def'),
(r'impl\s+\w+', 'impl_block'),
"rust": {
"keywords": ["fn ", "let ", "mut ", "impl ", "struct ", "enum ", "match ", "use "],
"syntax_checks": [
(r"fn\s+\w+\s*\(", "function_def"),
(r"impl\s+\w+", "impl_block"),
],
},
'gdscript': { # Godot
'keywords': ['extends ', 'class_name ', 'func ', 'var ', 'const ', 'signal ', 'export', 'onready'],
'syntax_checks': [
(r'extends\s+\w+', 'extends'),
(r'func\s+_\w+', 'built_in_method'),
(r'signal\s+\w+', 'signal_def'),
(r'@export', 'export_annotation'),
"gdscript": { # Godot
"keywords": [
"extends ",
"class_name ",
"func ",
"var ",
"const ",
"signal ",
"export",
"onready",
],
"syntax_checks": [
(r"extends\s+\w+", "extends"),
(r"func\s+_\w+", "built_in_method"),
(r"signal\s+\w+", "signal_def"),
(r"@export", "export_annotation"),
],
},
'yaml': {
'keywords': [],
'syntax_checks': [
(r'^\w+:\s*', 'key_value'),
(r'^-\s+\w+', 'list_item'),
"yaml": {
"keywords": [],
"syntax_checks": [
(r"^\w+:\s*", "key_value"),
(r"^-\s+\w+", "list_item"),
],
},
'json': {
'keywords': [],
'syntax_checks': [
(r'["\']\w+["\']\s*:', 'key_value'),
(r'\{[^}]*\}', 'object'),
(r'\[[^\]]*\]', 'array'),
"json": {
"keywords": [],
"syntax_checks": [
(r'["\']\w+["\']\s*:', "key_value"),
(r"\{[^}]*\}", "object"),
(r"\[[^\]]*\]", "array"),
],
},
'xml': {
'keywords': [],
'syntax_checks': [
(r'<\w+[^>]*>', 'opening_tag'),
(r'</\w+>', 'closing_tag'),
"xml": {
"keywords": [],
"syntax_checks": [
(r"<\w+[^>]*>", "opening_tag"),
(r"</\w+>", "closing_tag"),
],
},
'sql': {
'keywords': ['SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'TABLE'],
'syntax_checks': [
(r'SELECT\s+.+\s+FROM', 'select_statement'),
(r'CREATE\s+TABLE', 'create_table'),
"sql": {
"keywords": [
"SELECT",
"FROM",
"WHERE",
"INSERT",
"UPDATE",
"DELETE",
"CREATE",
"TABLE",
],
"syntax_checks": [
(r"SELECT\s+.+\s+FROM", "select_statement"),
(r"CREATE\s+TABLE", "create_table"),
],
},
'bash': {
'keywords': ['#!/bin/', 'echo ', 'if [', 'then', 'fi', 'for ', 'do', 'done'],
'syntax_checks': [
(r'#!/bin/\w+', 'shebang'),
(r'\$\w+', 'variable'),
"bash": {
"keywords": ["#!/bin/", "echo ", "if [", "then", "fi", "for ", "do", "done"],
"syntax_checks": [
(r"#!/bin/\w+", "shebang"),
(r"\$\w+", "variable"),
],
},
}
@@ -139,7 +165,7 @@ class QualityScorer:
return 0.0
code = code.strip()
lines = [line for line in code.split('\n') if line.strip()]
lines = [line for line in code.split("\n") if line.strip()]
# Factor 1: Length appropriateness
code_len = len(code)
@@ -161,13 +187,14 @@ class QualityScorer:
lang_patterns = self.LANGUAGE_PATTERNS[language]
# Check for keywords
keyword_matches = sum(1 for kw in lang_patterns['keywords'] if kw in code)
keyword_matches = sum(1 for kw in lang_patterns["keywords"] if kw in code)
if keyword_matches >= 2:
score += 1.0
# Check for syntax patterns
syntax_matches = sum(
1 for pattern, _ in lang_patterns['syntax_checks']
1
for pattern, _ in lang_patterns["syntax_checks"]
if re.search(pattern, code, re.MULTILINE)
)
if syntax_matches >= 1:
@@ -175,11 +202,11 @@ class QualityScorer:
# Factor 4: Structural quality
# Check for function/class definitions
if re.search(r'\b(def|function|func|fn|class|public class)\b', code):
if re.search(r"\b(def|function|func|fn|class|public class)\b", code):
score += 1.5
# Check for meaningful variable names (not just x, y, i)
meaningful_vars = re.findall(r'\b[a-z_][a-z0-9_]{3,}\b', code.lower())
meaningful_vars = re.findall(r"\b[a-z_][a-z0-9_]{3,}\b", code.lower())
if len(meaningful_vars) >= 3:
score += 0.5
@@ -192,8 +219,7 @@ class QualityScorer:
# Factor 6: Comment/code ratio
comment_lines = sum(
1 for line in lines
if line.strip().startswith(('#', '//', '/*', '*', '--', '<!--'))
1 for line in lines if line.strip().startswith(("#", "//", "/*", "*", "--", "<!--"))
)
if len(lines) > 0:
comment_ratio = comment_lines / len(lines)
@@ -210,7 +236,7 @@ class QualityScorer:
issues = []
# Check for balanced braces/brackets
pairs = [('{', '}'), ('[', ']'), ('(', ')')]
pairs = [("{", "}"), ("[", "]"), ("(", ")")]
for open_char, close_char in pairs:
open_count = code.count(open_char)
close_count = code.count(close_char)
@@ -218,26 +244,27 @@ class QualityScorer:
issues.append(f"Unbalanced {open_char}{close_char}")
# Check for common natural language indicators
common_words = ['the', 'and', 'for', 'with', 'this', 'that', 'have', 'from', 'they']
word_count = sum(1 for word in common_words if f' {word} ' in code.lower())
common_words = ["the", "and", "for", "with", "this", "that", "have", "from", "they"]
word_count = sum(1 for word in common_words if f" {word} " in code.lower())
if word_count > 5 and len(code.split()) < 100:
issues.append("May be natural language")
# Language-specific checks
if language == 'python':
if language == "python":
# Check for mixed indentation
indent_chars = set()
for line in code.split('\n'):
if line.startswith(' '):
indent_chars.add('space')
elif line.startswith('\t'):
indent_chars.add('tab')
for line in code.split("\n"):
if line.startswith(" "):
indent_chars.add("space")
elif line.startswith("\t"):
indent_chars.add("tab")
if len(indent_chars) > 1:
issues.append("Mixed tabs and spaces")
elif language == 'json':
elif language == "json":
try:
import json
json.loads(code)
except Exception as e:
issues.append(f"Invalid JSON: {str(e)[:50]}")
@@ -311,7 +338,7 @@ class QualityScorer:
score += 0.5
# Structure check
if '.' in content: # Has sentences
if "." in content: # Has sentences
score += 0.5
if content[0].isupper(): # Starts with capital
score += 0.5
@@ -327,7 +354,7 @@ class QualityScorer:
"""
code = code.strip()
if not code:
return 'unknown', 0.0
return "unknown", 0.0
scores = {}
@@ -335,18 +362,18 @@ class QualityScorer:
score = 0.0
# Check keywords
keyword_hits = sum(1 for kw in patterns['keywords'] if kw in code)
keyword_hits = sum(1 for kw in patterns["keywords"] if kw in code)
score += keyword_hits * 0.5
# Check syntax patterns
for pattern, _ in patterns['syntax_checks']:
for pattern, _ in patterns["syntax_checks"]:
if re.search(pattern, code, re.MULTILINE):
score += 1.0
scores[lang] = score
if not scores:
return 'unknown', 0.0
return "unknown", 0.0
best_lang = max(scores, key=scores.get)
best_score = scores[best_lang]

View File

@@ -21,9 +21,19 @@ from typing import Any
from .base_parser import BaseParser
from .unified_structure import (
Document, ContentBlock, ContentBlockType, CrossReference, CrossRefType,
AdmonitionType, Heading, CodeBlock, Table, Field, DefinitionItem,
Image, ListType
Document,
ContentBlock,
ContentBlockType,
CrossReference,
CrossRefType,
AdmonitionType,
Heading,
CodeBlock,
Table,
Field,
DefinitionItem,
Image,
ListType,
)
from .quality_scorer import QualityScorer
@@ -36,49 +46,71 @@ class RstParser(BaseParser):
"""
# RST header underline characters (in order of level)
HEADER_CHARS = ['=', '-', '~', '^', '"', "'", '`', ':', '.', '_', '*', '+', '#']
HEADER_CHARS = ["=", "-", "~", "^", '"', "'", "`", ":", ".", "_", "*", "+", "#"]
# Admonition directives
ADMONITION_DIRECTIVES = {
'note': AdmonitionType.NOTE,
'warning': AdmonitionType.WARNING,
'tip': AdmonitionType.TIP,
'hint': AdmonitionType.HINT,
'important': AdmonitionType.IMPORTANT,
'caution': AdmonitionType.CAUTION,
'danger': AdmonitionType.DANGER,
'attention': AdmonitionType.ATTENTION,
'error': AdmonitionType.ERROR,
'deprecated': AdmonitionType.DEPRECATED,
'versionadded': AdmonitionType.VERSIONADDED,
'versionchanged': AdmonitionType.VERSIONCHANGED,
"note": AdmonitionType.NOTE,
"warning": AdmonitionType.WARNING,
"tip": AdmonitionType.TIP,
"hint": AdmonitionType.HINT,
"important": AdmonitionType.IMPORTANT,
"caution": AdmonitionType.CAUTION,
"danger": AdmonitionType.DANGER,
"attention": AdmonitionType.ATTENTION,
"error": AdmonitionType.ERROR,
"deprecated": AdmonitionType.DEPRECATED,
"versionadded": AdmonitionType.VERSIONADDED,
"versionchanged": AdmonitionType.VERSIONCHANGED,
}
# Cross-reference patterns
CROSS_REF_PATTERNS = [
(r':ref:`([^`]+)`', CrossRefType.REF),
(r':doc:`([^`]+)`', CrossRefType.DOC),
(r':class:`([^`]+)`', CrossRefType.CLASS),
(r':meth:`([^`]+)`', CrossRefType.METH),
(r':func:`([^`]+)`', CrossRefType.FUNC),
(r':attr:`([^`]+)`', CrossRefType.ATTR),
(r':signal:`([^`]+)`', CrossRefType.SIGNAL), # Godot
(r':enum:`([^`]+)`', CrossRefType.ENUM), # Godot
(r':mod:`([^`]+)`', CrossRefType.MOD),
(r':data:`([^`]+)`', CrossRefType.DATA),
(r':exc:`([^`]+)`', CrossRefType.EXC),
(r":ref:`([^`]+)`", CrossRefType.REF),
(r":doc:`([^`]+)`", CrossRefType.DOC),
(r":class:`([^`]+)`", CrossRefType.CLASS),
(r":meth:`([^`]+)`", CrossRefType.METH),
(r":func:`([^`]+)`", CrossRefType.FUNC),
(r":attr:`([^`]+)`", CrossRefType.ATTR),
(r":signal:`([^`]+)`", CrossRefType.SIGNAL), # Godot
(r":enum:`([^`]+)`", CrossRefType.ENUM), # Godot
(r":mod:`([^`]+)`", CrossRefType.MOD),
(r":data:`([^`]+)`", CrossRefType.DATA),
(r":exc:`([^`]+)`", CrossRefType.EXC),
]
# Field list fields (common in docstrings)
FIELD_NAMES = [
'param', 'parameter', 'arg', 'argument',
'type', 'vartype', 'types',
'returns', 'return', 'rtype', 'returntype',
'raises', 'raise', 'except', 'exception',
'yields', 'yield', 'ytype',
'seealso', 'see', 'note', 'warning',
'todo', 'deprecated', 'versionadded', 'versionchanged',
'args', 'kwargs', 'keyword', 'keywords',
"param",
"parameter",
"arg",
"argument",
"type",
"vartype",
"types",
"returns",
"return",
"rtype",
"returntype",
"raises",
"raise",
"except",
"exception",
"yields",
"yield",
"ytype",
"seealso",
"see",
"note",
"warning",
"todo",
"deprecated",
"versionadded",
"versionchanged",
"args",
"kwargs",
"keyword",
"keywords",
]
def __init__(self, options: dict[str, Any] | None = None):
@@ -90,31 +122,31 @@ class RstParser(BaseParser):
@property
def format_name(self) -> str:
return 'restructuredtext'
return "restructuredtext"
@property
def supported_extensions(self) -> list[str]:
return ['.rst', '.rest']
return [".rst", ".rest"]
def _detect_format(self, content: str) -> bool:
"""Detect if content is RST."""
rst_indicators = [
r'\n[=-~^]+\n', # Underline headers
r'\.\.\s+\w+::', # Directives
r':\w+:`[^`]+`', # Cross-references
r'\.\.\s+_`[^`]+`:', # Targets
r"\n[=-~^]+\n", # Underline headers
r"\.\.\s+\w+::", # Directives
r":\w+:`[^`]+`", # Cross-references
r"\.\.\s+_`[^`]+`:", # Targets
]
return any(re.search(pattern, content) for pattern in rst_indicators)
def _parse_content(self, content: str, source_path: str) -> Document:
"""Parse RST content into Document."""
self._lines = content.split('\n')
self._lines = content.split("\n")
self._current_line = 0
self._substitutions = {}
document = Document(
title='',
format='rst',
title="",
format="rst",
source_path=source_path,
)
@@ -132,7 +164,7 @@ class RstParser(BaseParser):
# Extract title from first heading
for block in document.blocks:
if block.type == ContentBlockType.HEADING:
heading_data = block.metadata.get('heading_data')
heading_data = block.metadata.get("heading_data")
if heading_data:
document.title = heading_data.text
break
@@ -147,7 +179,7 @@ class RstParser(BaseParser):
def _collect_substitutions(self):
"""First pass: collect all substitution definitions."""
pattern = re.compile(r'^\.\.\s+\|([^|]+)\|\s+replace::\s*(.+)$')
pattern = re.compile(r"^\.\.\s+\|([^|]+)\|\s+replace::\s*(.+)$")
for i, line in enumerate(self._lines):
match = pattern.match(line)
if match:
@@ -169,11 +201,12 @@ class RstParser(BaseParser):
return None
# Skip comments
if stripped.startswith('.. ') and '::' not in stripped and not stripped.startswith('.. |'):
if stripped.startswith(".. ") and "::" not in stripped and not stripped.startswith(".. |"):
# Check if it's a comment
next_words = stripped[3:].split()
if (
not next_words or next_words[0] not in self.FIELD_NAMES + list(self.ADMONITION_DIRECTIVES.keys())
not next_words
or next_words[0] not in self.FIELD_NAMES + list(self.ADMONITION_DIRECTIVES.keys())
) and not any(c.isalpha() for c in stripped[3:]):
return None
@@ -182,7 +215,7 @@ class RstParser(BaseParser):
return self._parse_header()
# Directive
if stripped.startswith('.. '):
if stripped.startswith(".. "):
return self._parse_directive()
# Definition list
@@ -194,11 +227,11 @@ class RstParser(BaseParser):
return self._parse_field_list()
# Bullet list
if stripped.startswith(('- ', '* ', '+ ')):
if stripped.startswith(("- ", "* ", "+ ")):
return self._parse_bullet_list()
# Numbered list
if re.match(r'^\d+\.\s', stripped):
if re.match(r"^\d+\.\s", stripped):
return self._parse_numbered_list()
# Paragraph (default)
@@ -235,8 +268,8 @@ class RstParser(BaseParser):
level = self.HEADER_CHARS.index(char) + 1 if char in self.HEADER_CHARS else 1
# Create anchor ID
anchor = text.lower().replace(' ', '-').replace('_', '-')
anchor = re.sub(r'[^a-z0-9-]', '', anchor)
anchor = text.lower().replace(" ", "-").replace("_", "-")
anchor = re.sub(r"[^a-z0-9-]", "", anchor)
heading = Heading(
level=level,
@@ -251,7 +284,7 @@ class RstParser(BaseParser):
return ContentBlock(
type=ContentBlockType.HEADING,
content=text,
metadata={'heading_data': heading},
metadata={"heading_data": heading},
source_line=self._current_line,
)
@@ -261,7 +294,7 @@ class RstParser(BaseParser):
current = self._lines[line].strip()
# Extract directive name
match = re.match(r'^\.\.\s+([\w\-]+)::\s*(.*)$', current)
match = re.match(r"^\.\.\s+([\w\-]+)::\s*(.*)$", current)
if not match:
# Could be a comment or something else
return self._parse_paragraph()
@@ -277,46 +310,44 @@ class RstParser(BaseParser):
current_line = self._lines[self._current_line]
# Check for end of directive (non-indented line or new directive)
if current_line.strip() and not current_line.startswith(' '):
if current_line.strip() and not current_line.startswith(" "):
self._current_line -= 1 # Back up, this line belongs to next block
break
# Collect content (remove common indentation)
if current_line.startswith(' '):
if current_line.startswith(" "):
content_lines.append(current_line[3:])
elif current_line.startswith(' '):
elif current_line.startswith(" "):
content_lines.append(current_line[2:])
elif current_line.startswith(' '):
elif current_line.startswith(" "):
content_lines.append(current_line[1:])
elif current_line.strip():
content_lines.append(current_line)
else:
content_lines.append('')
content_lines.append("")
self._current_line += 1
content = '\n'.join(content_lines).strip()
content = "\n".join(content_lines).strip()
# Route to specific directive handler
if directive_name in self.ADMONITION_DIRECTIVES:
return self._parse_admonition_directive(
directive_name, argument, content, line + 1
)
elif directive_name == 'code-block':
return self._parse_admonition_directive(directive_name, argument, content, line + 1)
elif directive_name == "code-block":
return self._parse_code_block_directive(argument, content, line + 1)
elif directive_name == 'table':
elif directive_name == "table":
return self._parse_table_directive(argument, content, line + 1)
elif directive_name == 'list-table':
elif directive_name == "list-table":
return self._parse_list_table_directive(argument, content, line + 1)
elif directive_name == 'toctree':
elif directive_name == "toctree":
return self._parse_toctree_directive(content, line + 1)
elif directive_name == 'image' or directive_name == 'figure':
elif directive_name == "image" or directive_name == "figure":
return self._parse_image_directive(argument, content, line + 1)
elif directive_name == 'raw':
elif directive_name == "raw":
return ContentBlock(
type=ContentBlockType.RAW,
content=content,
metadata={'directive_name': directive_name, 'format': argument},
metadata={"directive_name": directive_name, "format": argument},
source_line=line + 1,
)
else:
@@ -324,39 +355,40 @@ class RstParser(BaseParser):
return ContentBlock(
type=ContentBlockType.DIRECTIVE,
content=content,
metadata={'directive_name': directive_name, 'argument': argument},
metadata={"directive_name": directive_name, "argument": argument},
source_line=line + 1,
)
def _parse_admonition_directive(self, name: str, argument: str,
content: str, line: int) -> ContentBlock:
def _parse_admonition_directive(
self, name: str, argument: str, content: str, line: int
) -> ContentBlock:
"""Parse an admonition directive (note, warning, etc.)."""
admonition_type = self.ADMONITION_DIRECTIVES.get(name, AdmonitionType.NOTE)
full_content = argument
if content:
full_content += '\n' + content if full_content else content
full_content += "\n" + content if full_content else content
return ContentBlock(
type=ContentBlockType.ADMONITION,
content=full_content,
metadata={
'admonition_type': admonition_type,
'directive_name': name,
"admonition_type": admonition_type,
"directive_name": name,
},
source_line=line,
)
def _parse_code_block_directive(self, language: str, content: str, line: int) -> ContentBlock:
"""Parse a code-block directive."""
lang = language.strip() or 'text'
lang = language.strip() or "text"
# Score the code
quality = self.quality_scorer.score_code_block(content, lang)
detected_lang, confidence = self.quality_scorer.detect_language(content)
# Use detected language if none specified and confidence is high
if lang == 'text' and confidence > 0.7:
if lang == "text" and confidence > 0.7:
lang = detected_lang
code_block = CodeBlock(
@@ -371,8 +403,8 @@ class RstParser(BaseParser):
type=ContentBlockType.CODE_BLOCK,
content=content,
metadata={
'code_data': code_block,
'language': lang,
"code_data": code_block,
"language": lang,
},
source_line=line,
quality_score=quality,
@@ -381,7 +413,7 @@ class RstParser(BaseParser):
def _parse_table_directive(self, caption: str, content: str, line: int) -> ContentBlock:
"""Parse a table directive (simple or grid table)."""
# Try to detect table type from content
if '+--' in content or '+==' in content:
if "+--" in content or "+==" in content:
table = self._parse_grid_table(content, caption, line)
else:
table = self._parse_simple_table(content, caption, line)
@@ -392,16 +424,15 @@ class RstParser(BaseParser):
type=ContentBlockType.TABLE,
content=f"[Table: {caption}]" if caption else "[Table]",
metadata={
'table_data': table,
"table_data": table,
},
source_line=line,
quality_score=quality,
)
def _parse_simple_table(self, content: str, caption: str | None,
line: int) -> Table:
def _parse_simple_table(self, content: str, caption: str | None, line: int) -> Table:
"""Parse a simple RST table (space-separated columns with = or - separators)."""
lines = content.split('\n')
lines = content.split("\n")
rows = []
headers = None
separator_indices = []
@@ -412,9 +443,9 @@ class RstParser(BaseParser):
# Match separator lines that contain = or - but no alphanumeric chars
if (
stripped
and re.match(r'^[\s=-]+$', stripped)
and any(c in stripped for c in '=-')
and re.search(r'={3,}|-{3,}', stripped)
and re.match(r"^[\s=-]+$", stripped)
and any(c in stripped for c in "=-")
and re.search(r"={3,}|-{3,}", stripped)
):
separator_indices.append(i)
@@ -426,7 +457,7 @@ class RstParser(BaseParser):
in_sep = True
start = 0
for j, char in enumerate(sep_line):
if char in '= -':
if char in "= -":
if not in_sep:
col_boundaries.append((start, j))
in_sep = True
@@ -442,18 +473,18 @@ class RstParser(BaseParser):
stripped = line_text.strip()
# Skip separator lines (handle both simple and grid table separators)
if re.match(r'^[\s=-]+$', stripped) and any(c in stripped for c in '=-'):
if re.match(r"^[\s=-]+$", stripped) and any(c in stripped for c in "=-"):
continue
if not stripped:
continue
if '|' in line_text:
if "|" in line_text:
# Pipe-delimited format
cells = [cell.strip() for cell in line_text.split('|')]
cells = [cell.strip() for cell in line_text.split("|")]
cells = [c for c in cells if c]
# Skip if all cells look like separators
if cells and not all(re.match(r'^[\s=-]+$', c) for c in cells):
if cells and not all(re.match(r"^[\s=-]+$", c) for c in cells):
rows.append(cells)
elif col_boundaries:
# Use column boundaries from separator
@@ -466,7 +497,7 @@ class RstParser(BaseParser):
rows.append(cells)
else:
# Fallback: split by 2+ spaces
cells = [cell.strip() for cell in re.split(r'\s{2,}', stripped)]
cells = [cell.strip() for cell in re.split(r"\s{2,}", stripped)]
cells = [c for c in cells if c]
if cells:
rows.append(cells)
@@ -482,9 +513,11 @@ class RstParser(BaseParser):
if i > first_sep and lines[i].strip():
# Check if this is a separator
stripped = lines[i].strip()
is_sep = bool(re.match(r'^[\s=-]+$', stripped) and
any(c in stripped for c in '=-') and
re.search(r'={3,}|-{3,}', stripped))
is_sep = bool(
re.match(r"^[\s=-]+$", stripped)
and any(c in stripped for c in "=-")
and re.search(r"={3,}|-{3,}", stripped)
)
if not is_sep:
first_row_idx = i
break
@@ -506,33 +539,32 @@ class RstParser(BaseParser):
rows=rows,
headers=headers,
caption=caption,
source_format='simple',
source_format="simple",
source_line=line,
)
def _parse_grid_table(self, content: str, caption: str | None,
line: int) -> Table:
def _parse_grid_table(self, content: str, caption: str | None, line: int) -> Table:
"""Parse a grid RST table."""
lines = content.split('\n')
lines = content.split("\n")
rows = []
headers = None
in_header = False
for i, line_text in enumerate(lines):
# Check for header separator (+=...=+)
if re.match(r'^\+[=+]+\+$', line_text.strip()):
if re.match(r"^\+[=+]+\+$", line_text.strip()):
in_header = True
continue
# Check for row separator (+-...-+)
if re.match(r'^\+[-+]+\+$', line_text.strip()):
if re.match(r"^\+[-+]+\+$", line_text.strip()):
in_header = False
continue
# Parse row
if '|' in line_text:
if "|" in line_text:
cells = []
parts = line_text.split('|')[1:-1] # Remove edges
parts = line_text.split("|")[1:-1] # Remove edges
for part in parts:
cell = part.strip()
if cell:
@@ -547,21 +579,20 @@ class RstParser(BaseParser):
rows=rows,
headers=headers,
caption=caption,
source_format='grid',
source_format="grid",
source_line=line,
)
def _parse_list_table_directive(self, caption: str, content: str,
line: int) -> ContentBlock:
def _parse_list_table_directive(self, caption: str, content: str, line: int) -> ContentBlock:
"""Parse a list-table directive."""
lines = content.split('\n')
lines = content.split("\n")
rows = []
headers = None
# Check for :header-rows: option
header_rows = 0
for line_text in lines:
match = re.match(r'^:header-rows:\s*(\d+)', line_text.strip())
match = re.match(r"^:header-rows:\s*(\d+)", line_text.strip())
if match:
header_rows = int(match.group(1))
break
@@ -572,13 +603,13 @@ class RstParser(BaseParser):
stripped = line_text.strip()
# New row
if re.match(r'^\*\s+-', stripped):
if re.match(r"^\*\s+-", stripped):
if current_row:
rows.append(current_row)
current_row = []
# Cell content
if stripped.startswith('- '):
if stripped.startswith("- "):
cell = stripped[2:].strip()
current_row.append(cell)
@@ -594,7 +625,7 @@ class RstParser(BaseParser):
rows=rows,
headers=headers,
caption=caption,
source_format='list-table',
source_format="list-table",
source_line=line,
)
@@ -603,7 +634,7 @@ class RstParser(BaseParser):
return ContentBlock(
type=ContentBlockType.TABLE,
content=f"[Table: {caption}]" if caption else "[Table]",
metadata={'table_data': table},
metadata={"table_data": table},
source_line=line,
quality_score=quality,
)
@@ -612,16 +643,18 @@ class RstParser(BaseParser):
"""Parse a toctree directive."""
entries = []
for line_text in content.split('\n'):
for line_text in content.split("\n"):
stripped = line_text.strip()
# Entries are simple lines or :hidden: etc options
if stripped and not stripped.startswith(':'):
if stripped and not stripped.startswith(":"):
entries.append(stripped)
return ContentBlock(
type=ContentBlockType.TOC_TREE,
content=f"ToC: {', '.join(entries[:5])}..." if len(entries) > 5 else f"ToC: {', '.join(entries)}",
metadata={'entries': entries},
content=f"ToC: {', '.join(entries[:5])}..."
if len(entries) > 5
else f"ToC: {', '.join(entries)}",
metadata={"entries": entries},
source_line=line,
)
@@ -632,14 +665,14 @@ class RstParser(BaseParser):
width = None
height = None
for line_text in content.split('\n'):
for line_text in content.split("\n"):
stripped = line_text.strip()
if stripped.startswith(':alt:'):
if stripped.startswith(":alt:"):
alt_text = stripped[5:].strip()
elif stripped.startswith(':width:'):
elif stripped.startswith(":width:"):
width = stripped[7:].strip()
elif stripped.startswith(':height:'):
elif stripped.startswith(":height:"):
height = stripped[8:].strip()
image = Image(
@@ -653,7 +686,7 @@ class RstParser(BaseParser):
return ContentBlock(
type=ContentBlockType.IMAGE,
content=argument,
metadata={'image_data': image},
metadata={"image_data": image},
source_line=line,
)
@@ -666,7 +699,9 @@ class RstParser(BaseParser):
next_line = self._lines[line + 1].strip()
# Definition list: term followed by indented definition starting with :
return next_line.startswith(': ') or (next_line and next_line[0].isspace() and ':' in current)
return next_line.startswith(": ") or (
next_line and next_line[0].isspace() and ":" in current
)
def _parse_definition_list(self) -> ContentBlock:
"""Parse a definition list."""
@@ -682,13 +717,13 @@ class RstParser(BaseParser):
self._current_line += 1
continue
if not line.startswith(' ') and items:
if not line.startswith(" ") and items:
# New non-indented item, end of list
self._current_line -= 1
break
# Check for term : classifier pattern (RST standard)
match = re.match(r'^([^:]+)\s+:\s+(.+)$', stripped)
match = re.match(r"^([^:]+)\s+:\s+(.+)$", stripped)
if match:
term = match.group(1).strip()
classifier = match.group(2).strip()
@@ -699,31 +734,33 @@ class RstParser(BaseParser):
while self._current_line < len(self._lines):
def_line = self._lines[self._current_line]
if def_line.strip() and not def_line.startswith(' '):
if def_line.strip() and not def_line.startswith(" "):
break
if def_line.startswith(' '):
if def_line.startswith(" "):
def_lines.append(def_line[3:])
elif def_line.startswith(' '):
elif def_line.startswith(" "):
def_lines.append(def_line[2:])
elif def_line.startswith(' '):
elif def_line.startswith(" "):
def_lines.append(def_line[1:])
self._current_line += 1
definition = ' '.join(def_lines).strip()
definition = " ".join(def_lines).strip()
items.append(DefinitionItem(
items.append(
DefinitionItem(
term=term,
definition=definition,
classifier=classifier,
source_line=start_line + 1,
))
)
)
else:
self._current_line += 1
return ContentBlock(
type=ContentBlockType.DEFINITION_LIST,
content=f"{len(items)} definitions",
metadata={'items': items},
metadata={"items": items},
source_line=start_line + 1,
)
@@ -732,7 +769,7 @@ class RstParser(BaseParser):
current = self._lines[line].strip()
# Field list: :fieldname: or :fieldname arg:
return re.match(r'^:(\w+)(\s+\w+)?:', current) is not None
return re.match(r"^:(\w+)(\s+\w+)?:", current) is not None
def _parse_field_list(self) -> ContentBlock:
"""Parse a field list."""
@@ -748,11 +785,11 @@ class RstParser(BaseParser):
self._current_line += 1
continue
if not line.startswith(':') and fields:
if not line.startswith(":") and fields:
break
# Parse field
match = re.match(r'^:(\w+)(?:\s+(\S+))?:(.*)$', stripped)
match = re.match(r"^:(\w+)(?:\s+(\S+))?:(.*)$", stripped)
if match:
name = match.group(1)
arg = match.group(2)
@@ -764,35 +801,39 @@ class RstParser(BaseParser):
while self._current_line < len(self._lines):
cont_line = self._lines[self._current_line]
if cont_line.strip() and not cont_line.startswith(' '):
if cont_line.strip() and not cont_line.startswith(" "):
break
if cont_line.startswith(' '):
if cont_line.startswith(" "):
content_lines.append(cont_line[3:])
elif cont_line.startswith(' '):
elif cont_line.startswith(" "):
content_lines.append(cont_line[2:])
elif cont_line.startswith(' '):
elif cont_line.startswith(" "):
content_lines.append(cont_line[1:])
self._current_line += 1
full_content = ' '.join(content_lines).strip()
full_content = " ".join(content_lines).strip()
fields.append(Field(
fields.append(
Field(
name=name,
arg=arg,
content=full_content,
source_line=start_line + 1,
))
)
)
else:
self._current_line += 1
# Back up one line if we broke on a non-field
if self._current_line < len(self._lines) and not self._lines[self._current_line].strip().startswith(':'):
if self._current_line < len(self._lines) and not self._lines[
self._current_line
].strip().startswith(":"):
self._current_line -= 1
return ContentBlock(
type=ContentBlockType.FIELD_LIST,
content=f"{len(fields)} fields",
metadata={'fields': fields},
metadata={"fields": fields},
source_line=start_line + 1,
)
@@ -809,7 +850,7 @@ class RstParser(BaseParser):
self._current_line += 1
continue
if not stripped.startswith(('- ', '* ', '+ ')):
if not stripped.startswith(("- ", "* ", "+ ")):
self._current_line -= 1
break
@@ -821,8 +862,8 @@ class RstParser(BaseParser):
type=ContentBlockType.LIST,
content=f"{len(items)} items",
metadata={
'list_type': ListType.BULLET,
'items': items,
"list_type": ListType.BULLET,
"items": items,
},
source_line=start_line + 1,
)
@@ -840,7 +881,7 @@ class RstParser(BaseParser):
self._current_line += 1
continue
match = re.match(r'^\d+\.\s+(.+)$', stripped)
match = re.match(r"^\d+\.\s+(.+)$", stripped)
if not match:
self._current_line -= 1
break
@@ -852,8 +893,8 @@ class RstParser(BaseParser):
type=ContentBlockType.LIST,
content=f"{len(items)} items",
metadata={
'list_type': ListType.NUMBERED,
'items': items,
"list_type": ListType.NUMBERED,
"items": items,
},
source_line=start_line + 1,
)
@@ -872,7 +913,7 @@ class RstParser(BaseParser):
break
# Check for special constructs
if stripped.startswith('.. ') or stripped.startswith(': '):
if stripped.startswith(".. ") or stripped.startswith(": "):
break
if self._is_header(self._current_line):
break
@@ -880,7 +921,7 @@ class RstParser(BaseParser):
lines.append(line)
self._current_line += 1
raw_content = ' '.join(lines).strip()
raw_content = " ".join(lines).strip()
# Extract cross-references from raw content before processing
xrefs, ext_links = self._extract_xrefs_from_text(raw_content, start_line + 1)
@@ -896,32 +937,32 @@ class RstParser(BaseParser):
# Store extracted references in metadata
if xrefs or ext_links:
block.metadata['cross_references'] = xrefs
block.metadata['external_links'] = ext_links
block.metadata["cross_references"] = xrefs
block.metadata["external_links"] = ext_links
return block
def _process_inline_markup(self, text: str) -> str:
"""Process inline RST markup."""
# Bold: **text** or *text*
text = re.sub(r'\*\*([^*]+)\*\*', r'**\1**', text)
text = re.sub(r"\*\*([^*]+)\*\*", r"**\1**", text)
# Italic: *text*
text = re.sub(r'(?<!\*)\*([^*]+)\*(?!\*)', r'*\1*', text)
text = re.sub(r"(?<!\*)\*([^*]+)\*(?!\*)", r"*\1*", text)
# Inline code: ``text``
text = re.sub(r'``([^`]+)``', r'`\1`', text)
text = re.sub(r"``([^`]+)``", r"`\1`", text)
# Links: `text <url>`_ -> [text](url)
text = re.sub(r'`([^<]+)<([^>]+)>`_', r'[\1](\2)', text)
text = re.sub(r"`([^<]+)<([^>]+)>`_", r"[\1](\2)", text)
# Cross-references: :type:`target` -> [target]
for pattern, ref_type in self.CROSS_REF_PATTERNS:
text = re.sub(pattern, r'[\1]', text)
text = re.sub(pattern, r"[\1]", text)
# Substitutions: |name| -> value
for name, value in self._substitutions.items():
text = text.replace(f'|{name}|', value)
text = text.replace(f"|{name}|", value)
return text
@@ -930,25 +971,25 @@ class RstParser(BaseParser):
for block in document.blocks:
# Extract headings
if block.type == ContentBlockType.HEADING:
heading_data = block.metadata.get('heading_data')
heading_data = block.metadata.get("heading_data")
if heading_data:
document.headings.append(heading_data)
# Extract code blocks
elif block.type == ContentBlockType.CODE_BLOCK:
code_data = block.metadata.get('code_data')
code_data = block.metadata.get("code_data")
if code_data:
document.code_blocks.append(code_data)
# Extract tables
elif block.type == ContentBlockType.TABLE:
table_data = block.metadata.get('table_data')
table_data = block.metadata.get("table_data")
if table_data:
document.tables.append(table_data)
# Extract cross-references from various sources
elif block.type == ContentBlockType.CROSS_REFERENCE:
xref_data = block.metadata.get('xref_data')
xref_data = block.metadata.get("xref_data")
if xref_data:
if xref_data.ref_type in (CrossRefType.REF, CrossRefType.DOC):
document.internal_links.append(xref_data)
@@ -957,33 +998,33 @@ class RstParser(BaseParser):
# Extract field lists
elif block.type == ContentBlockType.FIELD_LIST:
fields = block.metadata.get('fields', [])
fields = block.metadata.get("fields", [])
if fields:
document.field_lists.append(fields)
# Extract definition lists
elif block.type == ContentBlockType.DEFINITION_LIST:
items = block.metadata.get('items', [])
items = block.metadata.get("items", [])
if items:
document.definition_lists.append(items)
# Extract ToC trees
elif block.type == ContentBlockType.TOC_TREE:
entries = block.metadata.get('entries', [])
entries = block.metadata.get("entries", [])
if entries:
document.toc_trees.append(entries)
# Extract images
elif block.type == ContentBlockType.IMAGE:
image_data = block.metadata.get('image_data')
image_data = block.metadata.get("image_data")
if image_data:
document.images.append(image_data)
# Extract cross-references and links from paragraphs
elif block.type == ContentBlockType.PARAGRAPH:
# Get pre-extracted references from metadata
xrefs = block.metadata.get('cross_references', [])
ext_links = block.metadata.get('external_links', [])
xrefs = block.metadata.get("cross_references", [])
ext_links = block.metadata.get("external_links", [])
document.internal_links.extend(xrefs)
document.external_links.extend(ext_links)
@@ -1004,7 +1045,7 @@ class RstParser(BaseParser):
xrefs.append(xref)
# Extract external links (`text <url>`_)
for match in re.finditer(r'`([^<]+)<([^>]+)>`_', text):
for match in re.finditer(r"`([^<]+)<([^>]+)>`_", text):
link_text = match.group(1).strip()
url = match.group(2).strip()
xref = CrossReference(

View File

@@ -13,6 +13,7 @@ from enum import Enum
class ContentBlockType(Enum):
"""Standardized content block types across all formats."""
HEADING = "heading"
PARAGRAPH = "paragraph"
CODE_BLOCK = "code_block"
@@ -33,6 +34,7 @@ class ContentBlockType(Enum):
class CrossRefType(Enum):
"""Types of cross-references (mainly RST but useful for others)."""
REF = "ref" # :ref:`label`
DOC = "doc" # :doc:`path`
CLASS = "class" # :class:`ClassName`
@@ -50,6 +52,7 @@ class CrossRefType(Enum):
class AdmonitionType(Enum):
"""Types of admonitions/callouts."""
NOTE = "note"
WARNING = "warning"
TIP = "tip"
@@ -66,6 +69,7 @@ class AdmonitionType(Enum):
class ListType(Enum):
"""Types of lists."""
BULLET = "bullet"
NUMBERED = "numbered"
DEFINITION = "definition" # Term/definition pairs
@@ -74,6 +78,7 @@ class ListType(Enum):
@dataclass
class Heading:
"""A document heading/section title."""
level: int # 1-6 for h1-h6, or 1+ for RST underline levels
text: str
id: str | None = None # Anchor ID
@@ -83,6 +88,7 @@ class Heading:
@dataclass
class CodeBlock:
"""A code block with metadata."""
code: str
language: str | None = None
quality_score: float | None = None # 0-10
@@ -96,6 +102,7 @@ class CodeBlock:
@dataclass
class Table:
"""A table with rows and cells."""
rows: list[list[str]] # 2D array of cell content
headers: list[str] | None = None
caption: str | None = None
@@ -118,6 +125,7 @@ class Table:
@dataclass
class CrossReference:
"""A cross-reference link."""
ref_type: CrossRefType
target: str # Target ID, URL, or path
text: str | None = None # Display text (if different from target)
@@ -128,6 +136,7 @@ class CrossReference:
@dataclass
class Field:
"""A field in a field list (RST :param:, :returns:, etc.)."""
name: str # Field name (e.g., 'param', 'returns', 'type')
arg: str | None = None # Field argument (e.g., parameter name)
content: str = "" # Field content
@@ -137,6 +146,7 @@ class Field:
@dataclass
class DefinitionItem:
"""A definition list item (term + definition)."""
term: str
definition: str
classifier: str | None = None # RST classifier (term : classifier)
@@ -146,6 +156,7 @@ class DefinitionItem:
@dataclass
class Image:
"""An image reference or embedded image."""
source: str # URL, path, or base64 data
alt_text: str | None = None
width: int | None = None
@@ -157,6 +168,7 @@ class Image:
@dataclass
class ContentBlock:
"""Universal content block - used by ALL parsers."""
type: ContentBlockType
content: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
@@ -176,6 +188,7 @@ class ContentBlock:
@dataclass
class ExtractionStats:
"""Statistics about document extraction."""
total_blocks: int = 0
code_blocks: int = 0
tables: int = 0
@@ -194,6 +207,7 @@ class Document:
This class provides a standardized representation of document content
regardless of the source format (RST, Markdown, PDF, HTML).
"""
title: str = ""
format: str = "" # 'markdown', 'rst', 'pdf', 'html', 'unknown'
source_path: str = ""
@@ -241,6 +255,7 @@ class Document:
Markdown-formatted string
"""
from .formatters import MarkdownFormatter
formatter = MarkdownFormatter(options or {})
return formatter.format(self)
@@ -256,10 +271,7 @@ class Document:
"source_path": self.source_path,
"format": self.format,
"content": self._extract_content_text(),
"headings": [
{"level": h.level, "text": h.text, "id": h.id}
for h in self.headings
],
"headings": [{"level": h.level, "text": h.text, "id": h.id} for h in self.headings],
"code_samples": [
{
"code": cb.code,
@@ -290,7 +302,7 @@ class Document:
"code_blocks": self.stats.code_blocks,
"tables": self.stats.tables,
"headings": self.stats.headings,
}
},
}
def _extract_content_text(self) -> str:
@@ -317,7 +329,7 @@ class Document:
for block in self.blocks:
if block.type == ContentBlockType.HEADING:
heading_data = block.metadata.get('heading_data')
heading_data = block.metadata.get("heading_data")
if heading_data and heading_data.text == heading_text:
in_section = True
section_level = heading_data.level
@@ -342,6 +354,7 @@ class Document:
def find_tables_by_caption(self, pattern: str) -> list[Table]:
"""Find tables with captions matching a pattern."""
import re
return [t for t in self.tables if t.caption and re.search(pattern, t.caption, re.I)]
def get_api_summary(self) -> dict[str, Any]:
@@ -359,11 +372,11 @@ class Document:
for table in self.tables:
if table.caption:
cap_lower = table.caption.lower()
if 'property' in cap_lower:
if "property" in cap_lower:
properties_table = table
elif 'method' in cap_lower:
elif "method" in cap_lower:
methods_table = table
elif 'signal' in cap_lower:
elif "signal" in cap_lower:
signals_table = table
return {
@@ -385,7 +398,7 @@ class Document:
item = {"name": row[0]}
for i, header in enumerate(headers[1:], 1):
if i < len(row):
item[header.lower().replace(' ', '_')] = row[i]
item[header.lower().replace(" ", "_")] = row[i]
results.append(item)
return results

View File

@@ -80,6 +80,4 @@ class WorkflowsParser(SubcommandParser):
"validate",
help="Parse and validate a workflow by name or file path",
)
validate_p.add_argument(
"workflow_name", help="Workflow name or path to YAML file"
)
validate_p.add_argument("workflow_name", help="Workflow name or path to YAML file")

View File

@@ -712,7 +712,9 @@ def main():
print("=" * 80)
if workflow_executed:
print(f" Running after workflow: {workflow_name}")
print(" (Workflow provides specialized analysis, enhancement provides general improvements)")
print(
" (Workflow provides specialized analysis, enhancement provides general improvements)"
)
print(" (Use --enhance-workflow for more control)")
print("")
# Note: PDF scraper uses enhance_level instead of enhance/enhance_local

View File

@@ -81,9 +81,7 @@ class UnifiedEnhancer:
if config:
self.config = config
else:
self.config = EnhancementConfig(
mode=mode, api_key=api_key, enabled=enabled
)
self.config = EnhancementConfig(mode=mode, api_key=api_key, enabled=enabled)
# Get settings from config manager
if CONFIG_AVAILABLE:
@@ -115,9 +113,7 @@ class UnifiedEnhancer:
self.client = anthropic.Anthropic(**client_kwargs)
logger.info("✅ AI enhancement enabled (using Claude API)")
except ImportError:
logger.warning(
"⚠️ anthropic package not installed, falling back to LOCAL mode"
)
logger.warning("⚠️ anthropic package not installed, falling back to LOCAL mode")
self.config.mode = "local"
except Exception as e:
logger.warning(
@@ -170,13 +166,9 @@ class UnifiedEnhancer:
# Batch processing
batch_size = (
self.config.batch_size
if self.config.mode == "local"
else 5 # API uses smaller batches
)
parallel_workers = (
self.config.parallel_workers if self.config.mode == "local" else 1
self.config.batch_size if self.config.mode == "local" else 5 # API uses smaller batches
)
parallel_workers = self.config.parallel_workers if self.config.mode == "local" else 1
logger.info(
f"🤖 Enhancing {len(items)} {enhancement_type}s with AI "
@@ -200,9 +192,7 @@ class UnifiedEnhancer:
logger.info(f"✅ Enhanced {len(enhanced)} {enhancement_type}s")
return enhanced
def _enhance_parallel(
self, batches: list[list[dict]], prompt_template: str
) -> list[dict]:
def _enhance_parallel(self, batches: list[list[dict]], prompt_template: str) -> list[dict]:
"""Process batches in parallel using ThreadPoolExecutor."""
results = [None] * len(batches) # Preserve order
@@ -234,9 +224,7 @@ class UnifiedEnhancer:
enhanced.extend(batch_result)
return enhanced
def _enhance_batch(
self, items: list[dict], prompt_template: str
) -> list[dict]:
def _enhance_batch(self, items: list[dict], prompt_template: str) -> list[dict]:
"""Enhance a batch of items."""
# Prepare prompt
item_descriptions = []
@@ -244,9 +232,7 @@ class UnifiedEnhancer:
desc = self._format_item_for_prompt(idx, item)
item_descriptions.append(desc)
prompt = prompt_template.format(
items="\n".join(item_descriptions), count=len(items)
)
prompt = prompt_template.format(items="\n".join(item_descriptions), count=len(items))
# Call AI
response = self._call_claude(prompt, max_tokens=3000)
@@ -267,9 +253,7 @@ class UnifiedEnhancer:
if "confidence_boost" in analysis and "confidence" in item:
boost = analysis["confidence_boost"]
if -0.2 <= boost <= 0.2:
item["confidence"] = min(
1.0, max(0.0, item["confidence"] + boost)
)
item["confidence"] = min(1.0, max(0.0, item["confidence"] + boost))
return items

View File

@@ -114,11 +114,7 @@ def run_workflows(
logger.info(f"\n🔗 Chaining {total} workflow(s) in sequence")
for idx, workflow_name in enumerate(named_workflows, 1):
header = (
f"\n{'=' * 80}\n"
f"🔄 Workflow {idx}/{total}: {workflow_name}\n"
f"{'=' * 80}"
)
header = f"\n{'=' * 80}\n🔄 Workflow {idx}/{total}: {workflow_name}\n{'=' * 80}"
logger.info(header)
try:
@@ -143,6 +139,7 @@ def run_workflows(
except Exception as exc:
logger.error(f"❌ Workflow '{workflow_name}' failed: {exc}")
import traceback
traceback.print_exc()
# ── Inline workflow ────────────────────────────────────────────────────
@@ -171,6 +168,7 @@ def run_workflows(
except Exception as exc:
logger.error(f"❌ Inline workflow failed: {exc}")
import traceback
traceback.print_exc()
if dry_run:

View File

@@ -80,9 +80,7 @@ def _list_user_workflow_names() -> list[str]:
"""Return names of user workflows (without extension) from USER_WORKFLOWS_DIR."""
if not USER_WORKFLOWS_DIR.exists():
return []
return sorted(
p.stem for p in USER_WORKFLOWS_DIR.iterdir() if p.suffix in (".yaml", ".yml")
)
return sorted(p.stem for p in USER_WORKFLOWS_DIR.iterdir() if p.suffix in (".yaml", ".yml"))
def cmd_list() -> int:
@@ -155,7 +153,9 @@ def cmd_copy(names: list[str]) -> int:
dest.write_text(text, encoding="utf-8")
print(f"Copied '{name}' to: {dest}")
print(f"Edit it with your favourite editor, then reference it as '--enhance-workflow {name}'")
print(
f"Edit it with your favourite editor, then reference it as '--enhance-workflow {name}'"
)
return rc

View File

@@ -24,6 +24,7 @@ except ImportError:
self.type = type
self.text = text
USER_WORKFLOWS_DIR = Path.home() / ".config" / "skill-seekers" / "workflows"
@@ -50,9 +51,7 @@ def _bundled_names() -> list[str]:
def _user_names() -> list[str]:
if not USER_WORKFLOWS_DIR.exists():
return []
return sorted(
p.stem for p in USER_WORKFLOWS_DIR.iterdir() if p.suffix in (".yaml", ".yml")
)
return sorted(p.stem for p in USER_WORKFLOWS_DIR.iterdir() if p.suffix in (".yaml", ".yml"))
def _read_bundled(name: str) -> str | None:

View File

@@ -122,6 +122,7 @@ class TestCreateCommandArgvForwarding:
def _make_args(self, **kwargs):
import argparse
defaults = {
"enhance_workflow": None,
"enhance_stage": None,
@@ -149,6 +150,7 @@ class TestCreateCommandArgvForwarding:
def _collect_argv(self, args):
from skill_seekers.cli.create_command import CreateCommand
cmd = CreateCommand(args)
argv = []
cmd._add_common_args(argv)

View File

@@ -222,7 +222,7 @@ class TestMarkdownParser:
@pytest.fixture
def md_content(self):
return '''---
return """---
title: Test Document
description: A test markdown file
---
@@ -271,7 +271,7 @@ def hello_world():
## Image
![Alt text](image.png)
'''
"""
@pytest.fixture
def parsed_doc(self, md_content):

View File

@@ -337,10 +337,13 @@ class TestRunWorkflowsDryRun:
mock_engine.workflow.description = "desc"
mock_engine.workflow.stages = []
with patch(
with (
patch(
"skill_seekers.cli.enhancement_workflow.WorkflowEngine",
return_value=mock_engine,
), pytest.raises(SystemExit) as exc:
),
pytest.raises(SystemExit) as exc,
):
run_workflows(args)
assert exc.value.code == 0
@@ -361,10 +364,13 @@ class TestRunWorkflowsDryRun:
m.workflow.stages = []
engines.append(m)
with patch(
with (
patch(
"skill_seekers.cli.enhancement_workflow.WorkflowEngine",
side_effect=engines,
), pytest.raises(SystemExit):
),
pytest.raises(SystemExit),
):
run_workflows(args)
for engine in engines:

View File

@@ -45,14 +45,13 @@ INVALID_YAML_NO_STAGES = textwrap.dedent("""\
# Fixtures & helpers
# ─────────────────────────────────────────────────────────────────────────────
@pytest.fixture
def tmp_user_dir(tmp_path, monkeypatch):
"""Redirect USER_WORKFLOWS_DIR in workflow_tools to a temp dir."""
fake_dir = tmp_path / "workflows"
fake_dir.mkdir()
monkeypatch.setattr(
"skill_seekers.mcp.tools.workflow_tools.USER_WORKFLOWS_DIR", fake_dir
)
monkeypatch.setattr("skill_seekers.mcp.tools.workflow_tools.USER_WORKFLOWS_DIR", fake_dir)
return fake_dir
@@ -66,6 +65,7 @@ def _mock_bundled_names(names=("default", "security-focus")):
def _mock_bundled_text(mapping: dict):
def _read(name):
return mapping.get(name)
return patch(
"skill_seekers.mcp.tools.workflow_tools._read_bundled",
side_effect=_read,
@@ -84,6 +84,7 @@ def _text(result) -> str:
# list_workflows_tool
# ─────────────────────────────────────────────────────────────────────────────
class TestListWorkflowsTool:
def test_lists_bundled_and_user(self, tmp_user_dir):
from skill_seekers.mcp.tools.workflow_tools import list_workflows_tool
@@ -116,6 +117,7 @@ class TestListWorkflowsTool:
# get_workflow_tool
# ─────────────────────────────────────────────────────────────────────────────
class TestGetWorkflowTool:
def test_get_bundled(self):
from skill_seekers.mcp.tools.workflow_tools import get_workflow_tool
@@ -155,6 +157,7 @@ class TestGetWorkflowTool:
# create_workflow_tool
# ─────────────────────────────────────────────────────────────────────────────
class TestCreateWorkflowTool:
def test_create_new_workflow(self, tmp_user_dir):
from skill_seekers.mcp.tools.workflow_tools import create_workflow_tool
@@ -174,9 +177,7 @@ class TestCreateWorkflowTool:
def test_create_invalid_yaml(self, tmp_user_dir):
from skill_seekers.mcp.tools.workflow_tools import create_workflow_tool
result = create_workflow_tool(
{"name": "bad", "content": INVALID_YAML_NO_STAGES}
)
result = create_workflow_tool({"name": "bad", "content": INVALID_YAML_NO_STAGES})
assert "invalid" in _text(result).lower() or "stages" in _text(result).lower()
def test_create_missing_name(self):
@@ -196,6 +197,7 @@ class TestCreateWorkflowTool:
# update_workflow_tool
# ─────────────────────────────────────────────────────────────────────────────
class TestUpdateWorkflowTool:
def test_update_user_workflow(self, tmp_user_dir):
from skill_seekers.mcp.tools.workflow_tools import update_workflow_tool
@@ -203,9 +205,7 @@ class TestUpdateWorkflowTool:
(tmp_user_dir / "my-wf.yaml").write_text("old content", encoding="utf-8")
with _mock_bundled_names([]):
result = update_workflow_tool(
{"name": "my-wf", "content": MINIMAL_YAML}
)
result = update_workflow_tool({"name": "my-wf", "content": MINIMAL_YAML})
text = _text(result)
assert "Updated" in text or "updated" in text.lower()
@@ -215,9 +215,7 @@ class TestUpdateWorkflowTool:
from skill_seekers.mcp.tools.workflow_tools import update_workflow_tool
with _mock_bundled_names(["default"]):
result = update_workflow_tool(
{"name": "default", "content": MINIMAL_YAML}
)
result = update_workflow_tool({"name": "default", "content": MINIMAL_YAML})
assert "bundled" in _text(result).lower()
@@ -227,9 +225,7 @@ class TestUpdateWorkflowTool:
(tmp_user_dir / "my-wf.yaml").write_text(MINIMAL_YAML, encoding="utf-8")
with _mock_bundled_names([]):
result = update_workflow_tool(
{"name": "my-wf", "content": INVALID_YAML_NO_STAGES}
)
result = update_workflow_tool({"name": "my-wf", "content": INVALID_YAML_NO_STAGES})
assert "invalid" in _text(result).lower() or "stages" in _text(result).lower()
@@ -240,9 +236,7 @@ class TestUpdateWorkflowTool:
(tmp_user_dir / "default.yaml").write_text("old", encoding="utf-8")
with _mock_bundled_names(["default"]):
result = update_workflow_tool(
{"name": "default", "content": MINIMAL_YAML}
)
result = update_workflow_tool({"name": "default", "content": MINIMAL_YAML})
text = _text(result)
# User has a file named 'default', so it should succeed
@@ -253,6 +247,7 @@ class TestUpdateWorkflowTool:
# delete_workflow_tool
# ─────────────────────────────────────────────────────────────────────────────
class TestDeleteWorkflowTool:
def test_delete_user_workflow(self, tmp_user_dir):
from skill_seekers.mcp.tools.workflow_tools import delete_workflow_tool

View File

@@ -80,6 +80,7 @@ def sample_yaml_file(tmp_path):
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def _mock_bundled(names=("default", "minimal", "security-focus")):
"""Patch list_bundled_workflows on the captured module object."""
return patch.object(_wf_cmd, "list_bundled_workflows", return_value=list(names))
@@ -87,8 +88,10 @@ def _mock_bundled(names=("default", "minimal", "security-focus")):
def _mock_bundled_text(name_to_text: dict):
"""Patch _bundled_yaml_text on the captured module object."""
def _bundled_yaml_text(name):
return name_to_text.get(name)
return patch.object(_wf_cmd, "_bundled_yaml_text", side_effect=_bundled_yaml_text)
@@ -96,6 +99,7 @@ def _mock_bundled_text(name_to_text: dict):
# cmd_list
# ─────────────────────────────────────────────────────────────────────────────
class TestCmdList:
def test_shows_bundled_and_user(self, capsys, tmp_user_dir):
(tmp_user_dir / "my-workflow.yaml").write_text(MINIMAL_YAML, encoding="utf-8")
@@ -131,6 +135,7 @@ class TestCmdList:
# cmd_show
# ─────────────────────────────────────────────────────────────────────────────
class TestCmdShow:
def test_show_bundled(self, capsys):
with patch.object(_wf_cmd, "_workflow_yaml_text", return_value=MINIMAL_YAML):
@@ -155,6 +160,7 @@ class TestCmdShow:
# cmd_copy
# ─────────────────────────────────────────────────────────────────────────────
class TestCmdCopy:
def test_copy_bundled_to_user_dir(self, capsys, tmp_user_dir):
with _mock_bundled_text({"security-focus": MINIMAL_YAML}):
@@ -206,6 +212,7 @@ class TestCmdCopy:
# cmd_add
# ─────────────────────────────────────────────────────────────────────────────
class TestCmdAdd:
def test_add_valid_yaml(self, capsys, tmp_user_dir, sample_yaml_file):
rc = cmd_add([str(sample_yaml_file)])
@@ -287,6 +294,7 @@ class TestCmdAdd:
# cmd_remove
# ─────────────────────────────────────────────────────────────────────────────
class TestCmdRemove:
def test_remove_user_workflow(self, capsys, tmp_user_dir):
wf = tmp_user_dir / "my-wf.yaml"
@@ -349,6 +357,7 @@ class TestCmdRemove:
# cmd_validate
# ─────────────────────────────────────────────────────────────────────────────
class TestCmdValidate:
def test_validate_bundled_by_name(self, capsys):
with patch.object(_wf_cmd, "WorkflowEngine") as mock_engine_cls:
@@ -388,6 +397,7 @@ class TestCmdValidate:
# main() entry point
# ─────────────────────────────────────────────────────────────────────────────
class TestMain:
def test_main_no_action_exits_0(self):
from skill_seekers.cli.workflows_command import main
@@ -419,7 +429,10 @@ class TestMain:
assert "name: test-workflow" in capsys.readouterr().out
def test_main_show_not_found_exits_1(self, capsys, tmp_user_dir):
with patch.object(_wf_cmd, "_workflow_yaml_text", return_value=None), pytest.raises(SystemExit) as exc:
with (
patch.object(_wf_cmd, "_workflow_yaml_text", return_value=None),
pytest.raises(SystemExit) as exc,
):
_wf_cmd.main(["show", "ghost"])
assert exc.value.code == 1
@@ -505,12 +518,14 @@ class TestMain:
# Parser argument binding
# ─────────────────────────────────────────────────────────────────────────────
class TestWorkflowsParserArgumentBinding:
"""Verify nargs='+' parsers produce lists with correct attribute names."""
def _parse(self, argv):
"""Parse argv through the standalone main() parser by capturing args."""
import argparse
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="action")