feat: add skill-seekers video --setup for GPU auto-detection and dependency installation
Auto-detects NVIDIA (CUDA), AMD (ROCm), or CPU-only GPU and installs the correct PyTorch variant + easyocr + all visual extraction dependencies. Removes easyocr from video-full pip extras to avoid pulling ~2GB of wrong CUDA packages on non-NVIDIA systems. New files: - video_setup.py (835 lines): GPU detection, PyTorch install, ROCm config, venv checks, system dep validation, module selection, verification - test_video_setup.py (60 tests): Full coverage of detection, install, verify Updated docs: CHANGELOG, AGENTS.md, CLAUDE.md, README.md, CLI_REFERENCE, FAQ, TROUBLESHOOTING, installation guide, video dependency plan All 2523 tests passing (15 skipped). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -495,6 +495,24 @@ VIDEO_ARGUMENTS: dict[str, dict[str, Any]] = {
|
||||
"help": "Use Claude Vision API as fallback for low-confidence code frames (requires ANTHROPIC_API_KEY, ~$0.004/frame)",
|
||||
},
|
||||
},
|
||||
"start_time": {
|
||||
"flags": ("--start-time",),
|
||||
"kwargs": {
|
||||
"type": str,
|
||||
"default": None,
|
||||
"metavar": "TIME",
|
||||
"help": "Start time for extraction (seconds, MM:SS, or HH:MM:SS). Single video only.",
|
||||
},
|
||||
},
|
||||
"end_time": {
|
||||
"flags": ("--end-time",),
|
||||
"kwargs": {
|
||||
"type": str,
|
||||
"default": None,
|
||||
"metavar": "TIME",
|
||||
"help": "End time for extraction (seconds, MM:SS, or HH:MM:SS). Single video only.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Multi-source config specific (from unified_scraper.py)
|
||||
|
||||
@@ -109,6 +109,31 @@ VIDEO_ARGUMENTS: dict[str, dict[str, Any]] = {
|
||||
"help": "Use Claude Vision API as fallback for low-confidence code frames (requires ANTHROPIC_API_KEY, ~$0.004/frame)",
|
||||
},
|
||||
},
|
||||
"start_time": {
|
||||
"flags": ("--start-time",),
|
||||
"kwargs": {
|
||||
"type": str,
|
||||
"default": None,
|
||||
"metavar": "TIME",
|
||||
"help": "Start time for extraction (seconds, MM:SS, or HH:MM:SS). Single video only.",
|
||||
},
|
||||
},
|
||||
"end_time": {
|
||||
"flags": ("--end-time",),
|
||||
"kwargs": {
|
||||
"type": str,
|
||||
"default": None,
|
||||
"metavar": "TIME",
|
||||
"help": "End time for extraction (seconds, MM:SS, or HH:MM:SS). Single video only.",
|
||||
},
|
||||
},
|
||||
"setup": {
|
||||
"flags": ("--setup",),
|
||||
"kwargs": {
|
||||
"action": "store_true",
|
||||
"help": "Auto-detect GPU and install visual extraction deps (PyTorch, easyocr, etc.)",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -398,6 +398,12 @@ class CreateCommand:
|
||||
vs = getattr(self.args, "visual_similarity", None)
|
||||
if vs is not None and vs != 3.0:
|
||||
argv.extend(["--visual-similarity", str(vs)])
|
||||
st = getattr(self.args, "start_time", None)
|
||||
if st is not None:
|
||||
argv.extend(["--start-time", str(st)])
|
||||
et = getattr(self.args, "end_time", None)
|
||||
if et is not None:
|
||||
argv.extend(["--end-time", str(et)])
|
||||
|
||||
# Call video_scraper with modified argv
|
||||
logger.debug(f"Calling video_scraper with argv: {argv}")
|
||||
|
||||
@@ -621,6 +621,11 @@ class VideoInfo:
|
||||
transcript_confidence: float = 0.0
|
||||
content_richness_score: float = 0.0
|
||||
|
||||
# Time-clipping metadata (None when full video is used)
|
||||
original_duration: float | None = None
|
||||
clip_start: float | None = None
|
||||
clip_end: float | None = None
|
||||
|
||||
# Consensus-based text tracking (Phase A-D)
|
||||
text_group_timeline: TextGroupTimeline | None = None
|
||||
audio_visual_alignments: list[AudioVisualAlignment] = field(default_factory=list)
|
||||
@@ -657,6 +662,9 @@ class VideoInfo:
|
||||
"extracted_at": self.extracted_at,
|
||||
"transcript_confidence": self.transcript_confidence,
|
||||
"content_richness_score": self.content_richness_score,
|
||||
"original_duration": self.original_duration,
|
||||
"clip_start": self.clip_start,
|
||||
"clip_end": self.clip_end,
|
||||
"text_group_timeline": self.text_group_timeline.to_dict()
|
||||
if self.text_group_timeline
|
||||
else None,
|
||||
@@ -698,6 +706,9 @@ class VideoInfo:
|
||||
extracted_at=data.get("extracted_at", ""),
|
||||
transcript_confidence=data.get("transcript_confidence", 0.0),
|
||||
content_richness_score=data.get("content_richness_score", 0.0),
|
||||
original_duration=data.get("original_duration"),
|
||||
clip_start=data.get("clip_start"),
|
||||
clip_end=data.get("clip_end"),
|
||||
text_group_timeline=timeline,
|
||||
audio_visual_alignments=[
|
||||
AudioVisualAlignment.from_dict(a) for a in data.get("audio_visual_alignments", [])
|
||||
@@ -739,6 +750,10 @@ class VideoSourceConfig:
|
||||
# Subtitle files
|
||||
subtitle_patterns: list[str] | None = None
|
||||
|
||||
# Time-clipping (single video only)
|
||||
clip_start: float | None = None
|
||||
clip_end: float | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> VideoSourceConfig:
|
||||
return cls(
|
||||
@@ -758,6 +773,8 @@ class VideoSourceConfig:
|
||||
max_segment_duration=data.get("max_segment_duration", 600.0),
|
||||
categories=data.get("categories"),
|
||||
subtitle_patterns=data.get("subtitle_patterns"),
|
||||
clip_start=data.get("clip_start"),
|
||||
clip_end=data.get("clip_end"),
|
||||
)
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
@@ -774,6 +791,23 @@ class VideoSourceConfig:
|
||||
)
|
||||
if sources_set > 1:
|
||||
errors.append("Video source must specify exactly one source type")
|
||||
|
||||
# Clip range validation
|
||||
has_clip = self.clip_start is not None or self.clip_end is not None
|
||||
if has_clip and self.playlist is not None:
|
||||
errors.append(
|
||||
"--start-time/--end-time cannot be used with --playlist. "
|
||||
"Clip range is for single videos only."
|
||||
)
|
||||
if (
|
||||
self.clip_start is not None
|
||||
and self.clip_end is not None
|
||||
and self.clip_start >= self.clip_end
|
||||
):
|
||||
errors.append(
|
||||
f"--start-time ({self.clip_start}s) must be before --end-time ({self.clip_end}s)"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
|
||||
@@ -83,9 +83,15 @@ def check_video_dependencies(require_full: bool = False) -> None:
|
||||
if missing:
|
||||
deps = ", ".join(missing)
|
||||
extra = "[video-full]" if require_full else "[video]"
|
||||
setup_hint = (
|
||||
"\nFor visual deps (GPU-aware): skill-seekers video --setup"
|
||||
if require_full
|
||||
else ""
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Missing video dependencies: {deps}\n"
|
||||
f'Install with: pip install "skill-seekers{extra}"\n'
|
||||
f'Install with: pip install "skill-seekers{extra}"'
|
||||
f"{setup_hint}\n"
|
||||
f"Or: pip install {' '.join(missing)}"
|
||||
)
|
||||
|
||||
@@ -105,6 +111,45 @@ def _sanitize_filename(title: str, max_length: int = 60) -> str:
|
||||
return name[:max_length]
|
||||
|
||||
|
||||
def parse_time_to_seconds(time_str: str) -> float:
    """Parse a time string into seconds.

    Accepted formats:
        - Plain seconds: ``"330"`` or ``"330.5"``
        - MM:SS: ``"5:30"``
        - HH:MM:SS: ``"00:05:30"``

    Args:
        time_str: Time string in one of the accepted formats.

    Returns:
        Time in seconds as a float.

    Raises:
        ValueError: If *time_str* cannot be parsed, has more than three
            colon-separated parts, or contains a negative component.
    """
    time_str = time_str.strip()
    if not time_str:
        raise ValueError("Empty time string")

    parts = time_str.split(":")
    try:
        if 1 <= len(parts) <= 3:
            values = [float(p) for p in parts]
            # Negative components ("-5", "1:-30") are not valid timestamps;
            # previously they parsed to negative/garbled second counts.
            if all(v >= 0 for v in values):
                # Horner-style fold in base 60 handles all three formats:
                # [s] -> s; [m, s] -> m*60+s; [h, m, s] -> h*3600+m*60+s.
                total = 0.0
                for v in values:
                    total = total * 60 + v
                return total
    except ValueError:
        pass  # fall through to the unified error message below
    raise ValueError(
        f"Invalid time format: '{time_str}'. "
        "Use seconds (330), MM:SS (5:30), or HH:MM:SS (00:05:30)"
    )
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS or MM:SS."""
|
||||
total = int(seconds)
|
||||
@@ -221,6 +266,10 @@ class VideoToSkillConverter:
|
||||
self.visual_similarity = config.get("visual_similarity", 3.0)
|
||||
self.vision_ocr = config.get("vision_ocr", False)
|
||||
|
||||
# Time-clipping (seconds, None = full video)
|
||||
self.start_time: float | None = config.get("start_time")
|
||||
self.end_time: float | None = config.get("end_time")
|
||||
|
||||
# Paths
|
||||
self.skill_dir = config.get("output") or f"output/{self.name}"
|
||||
self.data_file = f"output/{self.name}_video_extracted.json"
|
||||
@@ -265,6 +314,8 @@ class VideoToSkillConverter:
|
||||
languages=self.languages,
|
||||
visual_extraction=self.visual,
|
||||
whisper_model=self.whisper_model,
|
||||
clip_start=self.start_time,
|
||||
clip_end=self.end_time,
|
||||
)
|
||||
|
||||
videos: list[VideoInfo] = []
|
||||
@@ -317,6 +368,37 @@ class VideoToSkillConverter:
|
||||
if transcript_source == TranscriptSource.YOUTUBE_AUTO:
|
||||
video_info.transcript_confidence *= 0.8
|
||||
|
||||
# Apply time clipping to transcript and chapters
|
||||
clip_start = self.start_time
|
||||
clip_end = self.end_time
|
||||
if clip_start is not None or clip_end is not None:
|
||||
cs = clip_start or 0.0
|
||||
ce = clip_end or float("inf")
|
||||
|
||||
# Store original duration before clipping
|
||||
video_info.original_duration = video_info.duration
|
||||
video_info.clip_start = cs
|
||||
video_info.clip_end = clip_end # keep None if not set
|
||||
|
||||
# Filter transcript segments to clip range
|
||||
original_count = len(transcript_segments)
|
||||
transcript_segments = [
|
||||
seg for seg in transcript_segments if seg.end > cs and seg.start < ce
|
||||
]
|
||||
video_info.raw_transcript = transcript_segments
|
||||
logger.info(
|
||||
f" Clipped transcript: {len(transcript_segments)}/{original_count} "
|
||||
f"segments in range {_format_duration(cs)}-{_format_duration(ce) if clip_end else 'end'}"
|
||||
)
|
||||
|
||||
# Filter chapters to clip range
|
||||
if video_info.chapters:
|
||||
video_info.chapters = [
|
||||
ch
|
||||
for ch in video_info.chapters
|
||||
if ch.end_time > cs and ch.start_time < ce
|
||||
]
|
||||
|
||||
# Segment video
|
||||
segments = segment_video(video_info, transcript_segments, source_config)
|
||||
video_info.segments = segments
|
||||
@@ -336,7 +418,12 @@ class VideoToSkillConverter:
|
||||
import tempfile as _tmpmod
|
||||
|
||||
temp_video_dir = _tmpmod.mkdtemp(prefix="ss_video_")
|
||||
video_path = download_video(source, temp_video_dir)
|
||||
video_path = download_video(
|
||||
source,
|
||||
temp_video_dir,
|
||||
clip_start=self.start_time,
|
||||
clip_end=self.end_time,
|
||||
)
|
||||
|
||||
if video_path and os.path.exists(video_path):
|
||||
keyframes, code_blocks, timeline = extract_visual_data(
|
||||
@@ -347,6 +434,8 @@ class VideoToSkillConverter:
|
||||
min_gap=self.visual_min_gap,
|
||||
similarity_threshold=self.visual_similarity,
|
||||
use_vision_api=self.vision_ocr,
|
||||
clip_start=self.start_time,
|
||||
clip_end=self.end_time,
|
||||
)
|
||||
# Attach keyframes to segments
|
||||
for kf in keyframes:
|
||||
@@ -510,7 +599,13 @@ class VideoToSkillConverter:
|
||||
else:
|
||||
meta_parts.append(f"**Source:** {video.channel_name}")
|
||||
if video.duration > 0:
|
||||
meta_parts.append(f"**Duration:** {_format_duration(video.duration)}")
|
||||
dur_str = _format_duration(video.duration)
|
||||
if video.clip_start is not None or video.clip_end is not None:
|
||||
orig = _format_duration(video.original_duration) if video.original_duration else "?"
|
||||
cs = _format_duration(video.clip_start) if video.clip_start is not None else "0:00"
|
||||
ce = _format_duration(video.clip_end) if video.clip_end is not None else orig
|
||||
dur_str = f"{cs} - {ce} (of {orig})"
|
||||
meta_parts.append(f"**Duration:** {dur_str}")
|
||||
if video.upload_date:
|
||||
meta_parts.append(f"**Published:** {video.upload_date}")
|
||||
|
||||
@@ -737,7 +832,21 @@ class VideoToSkillConverter:
|
||||
else:
|
||||
meta.append(video.channel_name)
|
||||
if video.duration > 0:
|
||||
meta.append(_format_duration(video.duration))
|
||||
dur_str = _format_duration(video.duration)
|
||||
if video.clip_start is not None or video.clip_end is not None:
|
||||
orig = (
|
||||
_format_duration(video.original_duration)
|
||||
if video.original_duration
|
||||
else "?"
|
||||
)
|
||||
cs = (
|
||||
_format_duration(video.clip_start)
|
||||
if video.clip_start is not None
|
||||
else "0:00"
|
||||
)
|
||||
ce = _format_duration(video.clip_end) if video.clip_end is not None else orig
|
||||
dur_str = f"Clip {cs}-{ce} (of {orig})"
|
||||
meta.append(dur_str)
|
||||
if video.view_count is not None:
|
||||
meta.append(f"{_format_count(video.view_count)} views")
|
||||
if meta:
|
||||
@@ -817,6 +926,12 @@ Examples:
|
||||
add_video_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --setup: run GPU detection + dependency installation, then exit
|
||||
if getattr(args, "setup", False):
|
||||
from skill_seekers.cli.video_setup import run_setup
|
||||
|
||||
return run_setup(interactive=True)
|
||||
|
||||
# Setup logging
|
||||
log_level = logging.DEBUG if args.verbose else (logging.WARNING if args.quiet else logging.INFO)
|
||||
logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")
|
||||
@@ -834,6 +949,29 @@ Examples:
|
||||
if not has_source and not has_json:
|
||||
parser.error("Must specify --url, --video-file, --playlist, or --from-json")
|
||||
|
||||
# Parse and validate time clipping
|
||||
raw_start = getattr(args, "start_time", None)
|
||||
raw_end = getattr(args, "end_time", None)
|
||||
clip_start: float | None = None
|
||||
clip_end: float | None = None
|
||||
|
||||
if raw_start is not None:
|
||||
try:
|
||||
clip_start = parse_time_to_seconds(raw_start)
|
||||
except ValueError as exc:
|
||||
parser.error(f"--start-time: {exc}")
|
||||
if raw_end is not None:
|
||||
try:
|
||||
clip_end = parse_time_to_seconds(raw_end)
|
||||
except ValueError as exc:
|
||||
parser.error(f"--end-time: {exc}")
|
||||
|
||||
if clip_start is not None or clip_end is not None:
|
||||
if getattr(args, "playlist", None):
|
||||
parser.error("--start-time/--end-time cannot be used with --playlist")
|
||||
if clip_start is not None and clip_end is not None and clip_start >= clip_end:
|
||||
parser.error(f"--start-time ({clip_start}s) must be before --end-time ({clip_end}s)")
|
||||
|
||||
# Build config
|
||||
config = {
|
||||
"name": args.name or "video_skill",
|
||||
@@ -849,6 +987,8 @@ Examples:
|
||||
"visual_min_gap": getattr(args, "visual_min_gap", 0.5),
|
||||
"visual_similarity": getattr(args, "visual_similarity", 3.0),
|
||||
"vision_ocr": getattr(args, "vision_ocr", False),
|
||||
"start_time": clip_start,
|
||||
"end_time": clip_end,
|
||||
}
|
||||
|
||||
converter = VideoToSkillConverter(config)
|
||||
@@ -862,6 +1002,10 @@ Examples:
|
||||
logger.info(f" name: {config['name']}")
|
||||
logger.info(f" languages: {config['languages']}")
|
||||
logger.info(f" visual: {config['visual']}")
|
||||
if clip_start is not None or clip_end is not None:
|
||||
start_str = _format_duration(clip_start) if clip_start is not None else "start"
|
||||
end_str = _format_duration(clip_end) if clip_end is not None else "end"
|
||||
logger.info(f" clip range: {start_str} - {end_str}")
|
||||
return 0
|
||||
|
||||
# Workflow 1: Build from JSON
|
||||
|
||||
@@ -132,6 +132,8 @@ def segment_by_time_window(
|
||||
video_info: VideoInfo,
|
||||
transcript_segments: list[TranscriptSegment],
|
||||
window_seconds: float = 120.0,
|
||||
start_offset: float = 0.0,
|
||||
end_limit: float | None = None,
|
||||
) -> list[VideoSegment]:
|
||||
"""Segment video using fixed time windows.
|
||||
|
||||
@@ -139,6 +141,8 @@ def segment_by_time_window(
|
||||
video_info: Video metadata.
|
||||
transcript_segments: Raw transcript segments.
|
||||
window_seconds: Duration of each window in seconds.
|
||||
start_offset: Start segmentation at this time (seconds).
|
||||
end_limit: Stop segmentation at this time (seconds). None = full duration.
|
||||
|
||||
Returns:
|
||||
List of VideoSegment objects.
|
||||
@@ -149,10 +153,13 @@ def segment_by_time_window(
|
||||
if duration <= 0 and transcript_segments:
|
||||
duration = max(seg.end for seg in transcript_segments)
|
||||
|
||||
if end_limit is not None:
|
||||
duration = min(duration, end_limit)
|
||||
|
||||
if duration <= 0:
|
||||
return segments
|
||||
|
||||
current_time = start_offset
|
||||
index = 0
|
||||
|
||||
while current_time < duration:
|
||||
@@ -215,4 +222,10 @@ def segment_video(
|
||||
# Fallback to time-window
|
||||
window = config.time_window_seconds
|
||||
logger.info(f"Using time-window segmentation ({window}s windows)")
|
||||
return segment_by_time_window(video_info, transcript_segments, window)
|
||||
return segment_by_time_window(
|
||||
video_info,
|
||||
transcript_segments,
|
||||
window,
|
||||
start_offset=config.clip_start or 0.0,
|
||||
end_limit=config.clip_end,
|
||||
)
|
||||
|
||||
835
src/skill_seekers/cli/video_setup.py
Normal file
835
src/skill_seekers/cli/video_setup.py
Normal file
@@ -0,0 +1,835 @@
|
||||
"""GPU auto-detection and video dependency installation.
|
||||
|
||||
Detects NVIDIA (CUDA) or AMD (ROCm) GPUs using system tools (without
|
||||
requiring torch to be installed) and installs the correct PyTorch variant
|
||||
plus all visual extraction dependencies (easyocr, opencv, etc.).
|
||||
|
||||
Also handles:
|
||||
- Virtual environment creation (if not already in one)
|
||||
- System dependency checks (tesseract binary)
|
||||
- ROCm environment variable configuration (MIOPEN_FIND_MODE)
|
||||
|
||||
Usage:
|
||||
skill-seekers video --setup # Interactive, choose modules
|
||||
From MCP: run_setup(interactive=False)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import venv
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Structures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GPUVendor(Enum):
    """Detected GPU hardware vendor."""

    NVIDIA = "nvidia"  # detected via nvidia-smi
    AMD = "amd"  # detected via rocminfo, or lspci when ROCm is absent
    NONE = "none"  # no GPU found; CPU-only PyTorch will be installed
|
||||
|
||||
|
||||
@dataclass
class GPUInfo:
    """Result of GPU auto-detection."""

    # Hardware vendor (NVIDIA / AMD / NONE).
    vendor: GPUVendor
    # Human-readable GPU model name ("" when unknown).
    name: str = ""
    # CUDA or ROCm version string, e.g. "12.4" ("" when unknown).
    compute_version: str = ""
    # PyTorch wheel index URL matching the detected stack.
    index_url: str = ""
    # Extra human-readable notes (warnings, fallback reasons) shown to the user.
    details: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class SetupModules:
    """Which modules to install during setup.

    All flags default to True so a plain ``--setup`` installs everything.
    """

    torch: bool = True  # PyTorch — installed separately via install_torch()
    easyocr: bool = True  # pip "easyocr"
    opencv: bool = True  # pip "opencv-python-headless"
    tesseract: bool = True  # pip "pytesseract" (Python bindings; binary checked separately)
    scenedetect: bool = True  # pip "scenedetect[opencv]"
    whisper: bool = True  # pip "faster-whisper"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PyTorch Index URL Mapping
|
||||
# =============================================================================
|
||||
|
||||
_PYTORCH_BASE = "https://download.pytorch.org/whl"
|
||||
|
||||
|
||||
def _cuda_version_to_index_url(version: str) -> str:
    """Map a CUDA version string (e.g. ``"12.4"``) to the PyTorch index URL.

    Args:
        version: CUDA version as reported by ``nvidia-smi``.

    Returns:
        The PyTorch wheel index URL for the newest supported CUDA build that
        the driver can run, or the CPU-only index when *version* is
        unparseable or too old.
    """
    try:
        parts = version.split(".")
        # Compare as a (major, minor) tuple. The previous
        # ``major + minor / 10.0`` arithmetic broke for minor >= 10:
        # e.g. "11.12" became 12.2 and wrongly selected the cu121 wheels.
        ver = (int(parts[0]), int(parts[1]) if len(parts) > 1 else 0)
    except (ValueError, IndexError):
        return f"{_PYTORCH_BASE}/cpu"

    if ver >= (12, 4):
        return f"{_PYTORCH_BASE}/cu124"
    if ver >= (12, 1):
        return f"{_PYTORCH_BASE}/cu121"
    if ver >= (11, 8):
        return f"{_PYTORCH_BASE}/cu118"
    return f"{_PYTORCH_BASE}/cpu"
|
||||
|
||||
|
||||
def _rocm_version_to_index_url(version: str) -> str:
    """Map a ROCm version string (e.g. ``"6.2"``) to the PyTorch index URL.

    Args:
        version: ROCm version read from ``/opt/rocm/.info/version``.

    Returns:
        The PyTorch wheel index URL for the matching ROCm build, or the
        CPU-only index when *version* is unparseable or unsupported.
    """
    try:
        parts = version.split(".")
        # Compare as a (major, minor) tuple. The previous
        # ``major + minor / 10.0`` arithmetic broke for minor >= 10:
        # e.g. "5.15" became 6.5 and wrongly selected the rocm6.3 wheels.
        ver = (int(parts[0]), int(parts[1]) if len(parts) > 1 else 0)
    except (ValueError, IndexError):
        return f"{_PYTORCH_BASE}/cpu"

    if ver >= (6, 3):
        return f"{_PYTORCH_BASE}/rocm6.3"
    if ver >= (6, 0):
        return f"{_PYTORCH_BASE}/rocm6.2.4"
    return f"{_PYTORCH_BASE}/cpu"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GPU Detection (without torch)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def detect_gpu() -> GPUInfo:
    """Detect GPU vendor and compute version using system tools.

    Detection order:
        1. nvidia-smi -> NVIDIA + CUDA version
        2. rocminfo -> AMD + ROCm version
        3. lspci -> AMD GPU present but no ROCm (warn)
        4. Fallback -> CPU-only
    """
    # Probe in priority order; the first detector that succeeds wins.
    for probe in (_check_nvidia, _check_amd_rocm, _check_amd_lspci):
        detected = probe()
        if detected is not None:
            return detected

    # Nothing detected: fall back to CPU-only PyTorch.
    return GPUInfo(
        vendor=GPUVendor.NONE,
        name="CPU-only",
        index_url=f"{_PYTORCH_BASE}/cpu",
        details=["No GPU detected, will use CPU-only PyTorch"],
    )
|
||||
|
||||
|
||||
def _check_nvidia() -> GPUInfo | None:
    """Detect NVIDIA GPU via nvidia-smi.

    Returns:
        A GPUInfo with vendor NVIDIA and the CUDA-matched PyTorch index URL,
        or None when nvidia-smi is absent, fails, or times out.
    """
    # No nvidia-smi binary on PATH means no usable NVIDIA driver stack.
    if not shutil.which("nvidia-smi"):
        return None
    try:
        result = subprocess.run(
            ["nvidia-smi"],
            capture_output=True,
            text=True,
            timeout=10,  # nvidia-smi can hang on broken driver installs
        )
        if result.returncode != 0:
            return None

        output = result.stdout
        # Parse CUDA version from "CUDA Version: X.Y"
        cuda_match = re.search(r"CUDA Version:\s*(\d+\.\d+)", output)
        cuda_ver = cuda_match.group(1) if cuda_match else ""

        # Parse GPU name from the table row (e.g., "NVIDIA GeForce RTX 4090").
        # The row ends with the persistence-mode column ("On"/"Off").
        gpu_name = ""
        name_match = re.search(r"\|\s+(NVIDIA[^\|]+?)\s+(?:On|Off)\s+\|", output)
        if name_match:
            gpu_name = name_match.group(1).strip()

        # Without a CUDA version we cannot pick a GPU wheel; use CPU index.
        index_url = _cuda_version_to_index_url(cuda_ver) if cuda_ver else f"{_PYTORCH_BASE}/cpu"

        return GPUInfo(
            vendor=GPUVendor.NVIDIA,
            name=gpu_name or "NVIDIA GPU",
            compute_version=cuda_ver,
            index_url=index_url,
            details=[f"CUDA {cuda_ver}" if cuda_ver else "CUDA version unknown"],
        )
    except (subprocess.TimeoutExpired, OSError):
        # Treat any probe failure as "no NVIDIA GPU" and let later
        # detectors (ROCm / lspci / CPU fallback) take over.
        return None
|
||||
|
||||
|
||||
def _check_amd_rocm() -> GPUInfo | None:
    """Detect AMD GPU via rocminfo.

    Returns:
        A GPUInfo with vendor AMD and the ROCm-matched PyTorch index URL,
        or None when rocminfo is absent, fails, or times out.
    """
    # rocminfo only exists when the ROCm stack is installed.
    if not shutil.which("rocminfo"):
        return None
    try:
        result = subprocess.run(
            ["rocminfo"],
            capture_output=True,
            text=True,
            timeout=10,  # guard against hangs on misconfigured ROCm
        )
        if result.returncode != 0:
            return None

        output = result.stdout
        # Parse GPU name from "Name: gfx..." or "Marketing Name: ..."
        gpu_name = ""
        marketing_match = re.search(r"Marketing Name:\s*(.+)", output)
        if marketing_match:
            gpu_name = marketing_match.group(1).strip()

        # Get ROCm version from /opt/rocm/.info/version
        rocm_ver = _read_rocm_version()

        # Unknown ROCm version -> safest to install CPU-only wheels.
        index_url = _rocm_version_to_index_url(rocm_ver) if rocm_ver else f"{_PYTORCH_BASE}/cpu"

        return GPUInfo(
            vendor=GPUVendor.AMD,
            name=gpu_name or "AMD GPU",
            compute_version=rocm_ver,
            index_url=index_url,
            details=[f"ROCm {rocm_ver}" if rocm_ver else "ROCm version unknown"],
        )
    except (subprocess.TimeoutExpired, OSError):
        # Probe failure: fall through to the lspci / CPU-only detectors.
        return None
|
||||
|
||||
|
||||
def _read_rocm_version() -> str:
|
||||
"""Read ROCm version from /opt/rocm/.info/version."""
|
||||
try:
|
||||
with open("/opt/rocm/.info/version") as f:
|
||||
return f.read().strip().split("-")[0]
|
||||
except (OSError, IOError):
|
||||
return ""
|
||||
|
||||
|
||||
def _check_amd_lspci() -> GPUInfo | None:
    """Detect an AMD GPU via lspci when ROCm isn't installed.

    Returns a CPU-fallback GPUInfo carrying a "ROCm not installed" warning
    when an AMD/ATI display controller is listed, otherwise None.
    """
    if shutil.which("lspci") is None:
        return None
    try:
        proc = subprocess.run(
            ["lspci"],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (subprocess.TimeoutExpired, OSError):
        return None
    if proc.returncode != 0:
        return None

    # An AMD/ATI VGA or Display controller line means GPU hardware is
    # present even though rocminfo was not found.
    for entry in proc.stdout.splitlines():
        looks_display = "VGA" in entry or "Display" in entry
        looks_amd = "AMD" in entry or "ATI" in entry
        if looks_display and looks_amd:
            label = entry.split(":")[-1].strip() if ":" in entry else "AMD GPU"
            return GPUInfo(
                vendor=GPUVendor.AMD,
                name=label,
                compute_version="",
                index_url=f"{_PYTORCH_BASE}/cpu",
                details=[
                    "AMD GPU detected but ROCm is not installed",
                    "Install ROCm first for GPU acceleration: https://rocm.docs.amd.com/",
                    "Falling back to CPU-only PyTorch",
                ],
            )
    return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Virtual Environment
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def is_in_venv() -> bool:
    """Return True when this interpreter runs inside a virtual environment.

    In a venv, ``sys.prefix`` points at the venv while ``sys.base_prefix``
    still points at the base installation; outside a venv they are equal.
    """
    prefix, base = sys.prefix, sys.base_prefix
    return prefix != base
|
||||
|
||||
|
||||
def create_venv(venv_path: str = ".venv") -> bool:
|
||||
"""Create a virtual environment and return True on success."""
|
||||
path = Path(venv_path).resolve()
|
||||
if path.exists():
|
||||
logger.info(f"Venv already exists at {path}")
|
||||
return True
|
||||
try:
|
||||
venv.create(str(path), with_pip=True)
|
||||
return True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error(f"Failed to create venv: {exc}")
|
||||
return False
|
||||
|
||||
|
||||
def get_venv_python(venv_path: str = ".venv") -> str:
|
||||
"""Return the python executable path inside a venv."""
|
||||
path = Path(venv_path).resolve()
|
||||
if platform.system() == "Windows":
|
||||
return str(path / "Scripts" / "python.exe")
|
||||
return str(path / "bin" / "python")
|
||||
|
||||
|
||||
def get_venv_activate_cmd(venv_path: str = ".venv") -> str:
|
||||
"""Return the shell command to activate the venv."""
|
||||
path = Path(venv_path).resolve()
|
||||
if platform.system() == "Windows":
|
||||
return str(path / "Scripts" / "activate")
|
||||
return f"source {path}/bin/activate"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# System Dependency Checks
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _detect_distro() -> str:
|
||||
"""Detect Linux distro family for install command suggestions."""
|
||||
try:
|
||||
with open("/etc/os-release") as f:
|
||||
content = f.read().lower()
|
||||
if "arch" in content or "manjaro" in content or "endeavour" in content:
|
||||
return "arch"
|
||||
if "debian" in content or "ubuntu" in content or "mint" in content or "pop" in content:
|
||||
return "debian"
|
||||
if "fedora" in content or "rhel" in content or "centos" in content or "rocky" in content:
|
||||
return "fedora"
|
||||
if "opensuse" in content or "suse" in content:
|
||||
return "suse"
|
||||
except OSError:
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _get_tesseract_install_cmd() -> str:
    """Return a distro-specific command to install the tesseract binary."""
    commands = {
        "arch": "sudo pacman -S tesseract tesseract-data-eng",
        "debian": "sudo apt install tesseract-ocr tesseract-ocr-eng",
        "fedora": "sudo dnf install tesseract tesseract-langpack-eng",
        "suse": "sudo zypper install tesseract-ocr tesseract-ocr-traineddata-english",
    }
    # Generic advice when the distro family could not be determined.
    fallback = "Install tesseract-ocr with your package manager"
    return commands.get(_detect_distro(), fallback)
|
||||
|
||||
|
||||
def check_tesseract() -> dict[str, bool | str]:
    """Check if tesseract binary is installed and has English data.

    Returns dict with keys: installed, has_eng, install_cmd, version.
    The version/language probes are best-effort: failures leave the
    defaults ("" / False) in place rather than raising.
    """
    # Pessimistic defaults; install_cmd is always populated so callers can
    # show the right fix even when tesseract is missing entirely.
    result: dict[str, bool | str] = {
        "installed": False,
        "has_eng": False,
        "install_cmd": _get_tesseract_install_cmd(),
        "version": "",
    }

    tess_bin = shutil.which("tesseract")
    if not tess_bin:
        return result

    result["installed"] = True

    # Get version
    try:
        ver = subprocess.run(
            ["tesseract", "--version"],
            capture_output=True,
            text=True,
            timeout=5,
        )
        # Falls back to stderr — presumably some tesseract builds print the
        # version banner there; TODO confirm across versions.
        first_line = (ver.stdout or ver.stderr).split("\n")[0]
        result["version"] = first_line.strip()
    except (subprocess.TimeoutExpired, OSError):
        pass  # best-effort: keep version ""

    # Check for eng language data
    try:
        langs = subprocess.run(
            ["tesseract", "--list-langs"],
            capture_output=True,
            text=True,
            timeout=5,
        )
        # --list-langs output may land on stdout or stderr; search both.
        output = langs.stdout + langs.stderr
        result["has_eng"] = "eng" in output.split()
    except (subprocess.TimeoutExpired, OSError):
        pass  # best-effort: keep has_eng False

    return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ROCm Environment Configuration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def configure_rocm_env() -> list[str]:
    """Set environment variables for ROCm/MIOpen to work correctly.

    Variables already set by the user are left untouched, so the call is
    idempotent.

    Returns:
        List of "NAME=value" strings for every variable that was set.
    """
    applied: list[str] = []
    env = os.environ

    # MIOPEN_FIND_MODE=FAST avoids the workspace allocation issue
    # where MIOpen requires huge workspace but allocates 0 bytes
    if "MIOPEN_FIND_MODE" not in env:
        env["MIOPEN_FIND_MODE"] = "FAST"
        applied.append("MIOPEN_FIND_MODE=FAST")

    # Ensure MIOpen user DB has a writable location
    if "MIOPEN_USER_DB_PATH" not in env:
        miopen_db = os.path.expanduser("~/.config/miopen")
        os.makedirs(miopen_db, exist_ok=True)
        env["MIOPEN_USER_DB_PATH"] = miopen_db
        applied.append(f"MIOPEN_USER_DB_PATH={miopen_db}")

    return applied
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Installation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
_BASE_VIDEO_DEPS = ["yt-dlp", "youtube-transcript-api"]
|
||||
|
||||
|
||||
def _build_visual_deps(modules: SetupModules) -> list[str]:
    """Build the list of pip packages based on selected modules."""
    # Base video deps are always included — setup must leave video fully ready
    packages = list(_BASE_VIDEO_DEPS)
    optional = (
        (modules.easyocr, "easyocr"),
        (modules.opencv, "opencv-python-headless"),
        (modules.tesseract, "pytesseract"),
        (modules.scenedetect, "scenedetect[opencv]"),
        (modules.whisper, "faster-whisper"),
    )
    packages.extend(pkg for enabled, pkg in optional if enabled)
    return packages
|
||||
|
||||
|
||||
def install_torch(gpu_info: GPUInfo, python_exe: str | None = None) -> bool:
    """Install PyTorch + torchvision from the vendor-specific wheel index.

    Args:
        gpu_info: GPU detection result carrying the pip ``--index-url`` for
            the correct variant (CUDA / ROCm / CPU).
        python_exe: Interpreter whose pip performs the install; defaults to
            the current interpreter.

    Returns:
        True if pip exited successfully, False on failure, timeout, or
        inability to launch the subprocess.
    """
    interpreter = python_exe if python_exe else sys.executable
    pip_cmd = [
        interpreter, "-m", "pip", "install",
        "torch", "torchvision",
        "--index-url", gpu_info.index_url,
    ]
    logger.info(f"Installing PyTorch from {gpu_info.index_url}")
    try:
        # PyTorch wheels are large; allow up to 10 minutes for the download.
        proc = subprocess.run(pip_cmd, timeout=600, capture_output=True, text=True)
    except subprocess.TimeoutExpired:
        logger.error("PyTorch installation timed out (10 min)")
        return False
    except OSError as exc:
        logger.error(f"PyTorch installation error: {exc}")
        return False
    if proc.returncode != 0:
        # Keep only the tail of stderr — pip output can be enormous.
        logger.error(f"PyTorch install failed:\n{proc.stderr[-500:]}")
        return False
    return True
|
||||
|
||||
|
||||
def install_visual_deps(
    modules: SetupModules | None = None, python_exe: str | None = None
) -> bool:
    """Install the selected visual-extraction pip packages.

    Args:
        modules: Module selection; ``None`` means install everything.
        python_exe: Interpreter whose pip performs the install; defaults to
            the current interpreter.

    Returns:
        True if nothing needed installing or pip succeeded; False on pip
        failure, timeout, or inability to launch the subprocess.
    """
    selection = modules or SetupModules()
    packages = _build_visual_deps(selection)
    if not packages:
        # Nothing selected — vacuous success.
        return True

    interpreter = python_exe if python_exe else sys.executable
    pip_cmd = [interpreter, "-m", "pip", "install", *packages]
    logger.info(f"Installing visual deps: {', '.join(packages)}")
    try:
        proc = subprocess.run(pip_cmd, timeout=600, capture_output=True, text=True)
    except subprocess.TimeoutExpired:
        logger.error("Visual deps installation timed out (10 min)")
        return False
    except OSError as exc:
        logger.error(f"Visual deps installation error: {exc}")
        return False
    if proc.returncode != 0:
        # Keep only the tail of stderr — pip output can be enormous.
        logger.error(f"Visual deps install failed:\n{proc.stderr[-500:]}")
        return False
    return True
|
||||
|
||||
|
||||
def install_skill_seekers(python_exe: str) -> bool:
    """Pip-install skill-seekers into the given interpreter's environment.

    Args:
        python_exe: Path to the target Python interpreter (e.g. a venv's).

    Returns:
        True on success; False on pip failure, timeout, or launch error.
    """
    pip_cmd = [python_exe, "-m", "pip", "install", "skill-seekers"]
    try:
        proc = subprocess.run(pip_cmd, timeout=300, capture_output=True, text=True)
    except (subprocess.TimeoutExpired, OSError):
        # Missing interpreter or a hung pip both count as a failed install.
        return False
    return proc.returncode == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Verification
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def verify_installation() -> dict[str, bool]:
    """Verify that all video dependencies are importable in this interpreter.

    Returns:
        Mapping of package label to import success. ``torch.cuda`` and
        ``torch.rocm`` report GPU backend availability (informational).
    """
    import importlib

    def _importable(module_name: str) -> bool:
        """True if ``module_name`` imports cleanly."""
        try:
            importlib.import_module(module_name)
            return True
        except ImportError:
            return False

    results: dict[str, bool] = {}

    # Base video deps
    results["yt-dlp"] = _importable("yt_dlp")
    results["youtube-transcript-api"] = _importable("youtube_transcript_api")

    # torch — also record which GPU backend (if any) is usable.
    try:
        import torch

        results["torch"] = True
        results["torch.cuda"] = torch.cuda.is_available()
        results["torch.rocm"] = hasattr(torch.version, "hip") and torch.version.hip is not None
    except ImportError:
        results["torch"] = False
        results["torch.cuda"] = False
        results["torch.rocm"] = False

    # Optional visual-extraction packages (label -> import name).
    for label, module_name in (
        ("easyocr", "easyocr"),
        ("opencv", "cv2"),
        ("pytesseract", "pytesseract"),
        ("scenedetect", "scenedetect"),
        ("faster-whisper", "faster_whisper"),
    ):
        results[label] = _importable(module_name)

    return results
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Module Selection (Interactive)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _ask_modules(interactive: bool) -> SetupModules:
    """Ask the user which modules to install.

    Non-interactive mode, an aborted prompt (EOF/Ctrl-C), or any answer
    other than an explicit "choose" selects everything (the defaults).
    """
    if not interactive:
        return SetupModules()

    print("Which modules do you want to install?")
    print(" [a] All (default)")
    print(" [c] Choose individually")
    try:
        selection = input(" > ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        print()
        return SetupModules()

    if selection not in ("c", "choose"):
        return SetupModules()

    chosen = SetupModules()

    chosen.torch = _interactive_yn("PyTorch (required for easyocr GPU)", default=True)
    chosen.easyocr = _interactive_yn("EasyOCR (text extraction from video frames)", default=True)
    chosen.opencv = _interactive_yn("OpenCV (frame extraction and image processing)", default=True)
    chosen.tesseract = _interactive_yn("pytesseract (secondary OCR engine)", default=True)
    chosen.scenedetect = _interactive_yn("scenedetect (scene change detection)", default=True)
    chosen.whisper = _interactive_yn("faster-whisper (local audio transcription)", default=True)

    return chosen
|
||||
|
||||
|
||||
def _interactive_yn(prompt: str, default: bool = True) -> bool:
    """Ask a yes/no question on stdin and return the answer as a bool.

    An empty answer, EOF, or Ctrl-C yields ``default``.
    """
    hint = "[Y/n]" if default else "[y/N]"
    try:
        reply = input(f" {prompt}? {hint} ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        return default
    # Empty reply means "accept the default"; anything else must be a yes.
    return reply in ("y", "yes") if reply else default
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Orchestrator
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def run_setup(interactive: bool = True) -> int:
    """Auto-detect GPU and install all visual extraction dependencies.

    Handles:
    1. Venv creation (if not in one)
    2. GPU detection
    3. Module selection (optional — interactive only)
    4. System dep checks (tesseract binary)
    5. ROCm env var configuration
    6. PyTorch installation (correct GPU variant)
    7. Visual deps installation
    8. Verification

    Args:
        interactive: If True, prompt user for confirmation before installing.

    Returns:
        0 on success, 1 on failure.
    """
    print("=" * 60)
    print(" Video Visual Extraction Setup")
    print("=" * 60)
    print()

    total_steps = 7

    # ── Step 1: Venv check ──
    # Decide which interpreter receives the installs: the active venv, a
    # freshly created ./.venv, or (on refusal / non-interactive) the current one.
    print(f"[1/{total_steps}] Checking environment...")
    if is_in_venv():
        print(f" Already in venv: {sys.prefix}")
        python_exe = sys.executable
    else:
        print(" Not in a virtual environment.")
        venv_path = ".venv"
        if interactive:
            try:
                answer = input(
                    f" Create venv at ./{venv_path}? [Y/n] "
                ).strip().lower()
            except (EOFError, KeyboardInterrupt):
                # Treat an aborted prompt as a cancelled setup.
                print("\nSetup cancelled.")
                return 1
            if answer and answer not in ("y", "yes"):
                # Explicit "no": fall back to the system interpreter.
                print(" Continuing without venv (installing to system Python).")
                python_exe = sys.executable
            else:
                # Empty answer counts as "yes" (the default).
                if not create_venv(venv_path):
                    print(" FAILED: Could not create venv.")
                    return 1
                python_exe = get_venv_python(venv_path)
                activate_cmd = get_venv_activate_cmd(venv_path)
                print(f" Venv created at ./{venv_path}")
                print(f" Installing skill-seekers into venv...")
                # The new venv needs skill-seekers itself, not just the deps.
                if not install_skill_seekers(python_exe):
                    print(" FAILED: Could not install skill-seekers into venv.")
                    return 1
                print(f" After setup completes, activate with:")
                print(f" {activate_cmd}")
        else:
            # Non-interactive: use current python
            python_exe = sys.executable
    print()

    # ── Step 2: GPU detection ──
    print(f"[2/{total_steps}] Detecting GPU...")
    gpu_info = detect_gpu()

    vendor_label = {
        GPUVendor.NVIDIA: "NVIDIA (CUDA)",
        GPUVendor.AMD: "AMD (ROCm)",
        GPUVendor.NONE: "CPU-only",
    }
    print(f" GPU: {gpu_info.name}")
    # .get() with fallback covers any vendor value not in the label map.
    print(f" Vendor: {vendor_label.get(gpu_info.vendor, gpu_info.vendor.value)}")
    if gpu_info.compute_version:
        print(f" Version: {gpu_info.compute_version}")
    for detail in gpu_info.details:
        print(f" {detail}")
    print(f" PyTorch index: {gpu_info.index_url}")
    print()

    # ── Step 3: Module selection ──
    print(f"[3/{total_steps}] Selecting modules...")
    modules = _ask_modules(interactive)
    deps = _build_visual_deps(modules)
    print(f" Selected: {', '.join(deps) if deps else '(none)'}")
    if modules.torch:
        # PyTorch is installed separately (Step 6) from its own index URL.
        print(f" + PyTorch + torchvision")
    print()

    # ── Step 4: System dependency check ──
    # pytesseract is only a wrapper; the tesseract *binary* and its English
    # language data must be present on the system. Warn, but don't abort.
    print(f"[4/{total_steps}] Checking system dependencies...")
    if modules.tesseract:
        tess = check_tesseract()
        if not tess["installed"]:
            print(f" WARNING: tesseract binary not found!")
            print(f" The pytesseract Python package needs the tesseract binary installed.")
            print(f" Install it with: {tess['install_cmd']}")
            print()
        elif not tess["has_eng"]:
            print(f" WARNING: tesseract installed ({tess['version']}) but English data missing!")
            print(f" Install with: {tess['install_cmd']}")
            print()
        else:
            print(f" tesseract: {tess['version']} (eng data OK)")
    else:
        print(" tesseract: skipped (not selected)")
    print()

    # ── Step 5: ROCm configuration ──
    # Only AMD needs env tweaks (MIOpen workarounds); NVIDIA/CPU need nothing.
    print(f"[5/{total_steps}] Configuring GPU environment...")
    if gpu_info.vendor == GPUVendor.AMD:
        changes = configure_rocm_env()
        if changes:
            print(" Set ROCm environment variables:")
            for c in changes:
                print(f" {c}")
            print(" (These fix MIOpen workspace allocation issues)")
        else:
            print(" ROCm env vars already configured.")
    elif gpu_info.vendor == GPUVendor.NVIDIA:
        print(" NVIDIA: no extra configuration needed.")
    else:
        print(" CPU-only: no GPU configuration needed.")
    print()

    # ── Step 6: Confirm and install ──
    if interactive:
        print("Ready to install. Summary:")
        if modules.torch:
            print(f" - PyTorch + torchvision (from {gpu_info.index_url})")
        for dep in deps:
            print(f" - {dep}")
        print()
        try:
            answer = input("Proceed? [Y/n] ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            print("\nSetup cancelled.")
            return 1
        if answer and answer not in ("y", "yes"):
            print("Setup cancelled.")
            return 1
        print()

    print(f"[6/{total_steps}] Installing packages...")
    if modules.torch:
        # Torch first: easyocr pulls in torch, so installing the correct
        # GPU variant beforehand prevents pip from grabbing the wrong one.
        print(" Installing PyTorch...")
        if not install_torch(gpu_info, python_exe):
            print(" FAILED: PyTorch installation failed.")
            # Echo the exact command so the user can retry manually.
            print(f" Try: {python_exe} -m pip install torch torchvision --index-url {gpu_info.index_url}")
            return 1
        print(" PyTorch installed.")

    if deps:
        print(" Installing visual packages...")
        if not install_visual_deps(modules, python_exe):
            print(" FAILED: Visual packages installation failed.")
            print(f" Try: {python_exe} -m pip install {' '.join(deps)}")
            return 1
        print(" Visual packages installed.")
    print()

    # ── Step 7: Verify ──
    # NOTE(review): verify_installation() imports into the *current*
    # interpreter; when packages were just installed into a new venv
    # (python_exe != sys.executable) this may report MISSING even though
    # the venv install succeeded — confirm intended behavior.
    print(f"[7/{total_steps}] Verifying installation...")
    results = verify_installation()
    all_ok = True
    for pkg, ok in results.items():
        status = "OK" if ok else "MISSING"
        print(f" {pkg}: {status}")
        # torch.cuda / torch.rocm are informational, not required
        if not ok and pkg not in ("torch.cuda", "torch.rocm"):
            # Only count as failure if the module was selected
            if pkg == "torch" and modules.torch:
                all_ok = False
            elif pkg == "easyocr" and modules.easyocr:
                all_ok = False
            elif pkg == "opencv" and modules.opencv:
                all_ok = False
            elif pkg == "pytesseract" and modules.tesseract:
                all_ok = False
            elif pkg == "scenedetect" and modules.scenedetect:
                all_ok = False
            elif pkg == "faster-whisper" and modules.whisper:
                all_ok = False

    print()
    if all_ok:
        print("Setup complete! You can now use: skill-seekers video --url <URL> --visual")
        if not is_in_venv() and python_exe != sys.executable:
            # NOTE(review): called without venv_path here, unlike Step 1 —
            # relies on get_venv_activate_cmd's default matching ".venv"; confirm.
            activate_cmd = get_venv_activate_cmd()
            print(f"\nDon't forget to activate the venv first:")
            print(f" {activate_cmd}")
    else:
        print("Some packages failed to install. Check the output above.")
        return 1

    return 0
|
||||
@@ -13,6 +13,7 @@ from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import difflib
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
@@ -32,6 +33,18 @@ from skill_seekers.cli.video_models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set ROCm/MIOpen env vars BEFORE importing torch (via easyocr).
|
||||
# Without MIOPEN_FIND_MODE=FAST, MIOpen tries to allocate huge workspace
|
||||
# buffers (300MB+), gets 0 bytes, and silently falls back to CPU kernels.
|
||||
import os as _os
|
||||
|
||||
if "MIOPEN_FIND_MODE" not in _os.environ:
|
||||
_os.environ["MIOPEN_FIND_MODE"] = "FAST"
|
||||
if "MIOPEN_USER_DB_PATH" not in _os.environ:
|
||||
_miopen_db = _os.path.expanduser("~/.config/miopen")
|
||||
_os.makedirs(_miopen_db, exist_ok=True)
|
||||
_os.environ["MIOPEN_USER_DB_PATH"] = _miopen_db
|
||||
|
||||
# Tier 2 dependency flags
|
||||
try:
|
||||
import cv2
|
||||
@@ -65,23 +78,46 @@ except ImportError:
|
||||
pytesseract = None # type: ignore[assignment]
|
||||
HAS_PYTESSERACT = False
|
||||
|
||||
# Circuit breaker: after first tesseract failure, disable it for the session.
|
||||
# Prevents wasting time spawning subprocesses that always fail.
|
||||
_tesseract_broken = False
|
||||
|
||||
|
||||
_INSTALL_MSG = (
|
||||
"Visual extraction requires additional dependencies.\n"
|
||||
'Install with: pip install "skill-seekers[video-full]"\n'
|
||||
"Or: pip install opencv-python-headless scenedetect easyocr"
|
||||
"Recommended: skill-seekers video --setup (auto-detects GPU, installs correct PyTorch)\n"
|
||||
'Alternative: pip install "skill-seekers[video-full]" (may install wrong PyTorch variant)'
|
||||
)
|
||||
|
||||
# Lazy-initialized EasyOCR reader (heavy, only load once)
|
||||
_ocr_reader = None
|
||||
|
||||
|
||||
def _detect_gpu() -> bool:
|
||||
"""Check if a CUDA or ROCm GPU is available for EasyOCR/PyTorch."""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return True
|
||||
# ROCm exposes GPU via torch.version.hip
|
||||
if hasattr(torch.version, "hip") and torch.version.hip is not None:
|
||||
return True
|
||||
return False
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_ocr_reader():
|
||||
"""Get or create the EasyOCR reader (lazy singleton)."""
|
||||
global _ocr_reader
|
||||
if _ocr_reader is None:
|
||||
logger.info("Initializing OCR engine (first run may download models)...")
|
||||
_ocr_reader = easyocr.Reader(["en"], gpu=False)
|
||||
use_gpu = _detect_gpu()
|
||||
logger.info(
|
||||
f"Initializing OCR engine ({'GPU' if use_gpu else 'CPU'} mode, "
|
||||
"first run may download models)..."
|
||||
)
|
||||
_ocr_reader = easyocr.Reader(["en"], gpu=use_gpu)
|
||||
return _ocr_reader
|
||||
|
||||
|
||||
@@ -296,11 +332,15 @@ def _run_tesseract_ocr(preprocessed_path: str, frame_type: FrameType) -> list[tu
|
||||
Returns results in the same format as EasyOCR: list of (bbox, text, confidence).
|
||||
Groups words into lines by y-coordinate.
|
||||
|
||||
Uses a circuit breaker: if tesseract fails once, it's disabled for the
|
||||
rest of the session to avoid wasting time on repeated subprocess failures.
|
||||
|
||||
Args:
|
||||
preprocessed_path: Path to the preprocessed grayscale image.
|
||||
frame_type: Frame classification (reserved for future per-type tuning).
|
||||
"""
|
||||
if not HAS_PYTESSERACT:
|
||||
global _tesseract_broken
|
||||
if not HAS_PYTESSERACT or _tesseract_broken:
|
||||
return []
|
||||
|
||||
# Produce clean binary for Tesseract
|
||||
@@ -312,7 +352,11 @@ def _run_tesseract_ocr(preprocessed_path: str, frame_type: FrameType) -> list[tu
|
||||
output_type=pytesseract.Output.DICT,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("pytesseract failed, returning empty results")
|
||||
_tesseract_broken = True
|
||||
logger.warning(
|
||||
"pytesseract failed — disabling for this session. "
|
||||
"Install tesseract binary: skill-seekers video --setup"
|
||||
)
|
||||
return []
|
||||
finally:
|
||||
if binary_path != preprocessed_path and os.path.exists(binary_path):
|
||||
@@ -897,6 +941,25 @@ def _crop_code_region(frame_path: str, bbox: tuple[int, int, int, int], suffix:
|
||||
return cropped_path
|
||||
|
||||
|
||||
def _frame_type_from_regions(
    regions: list[tuple[int, int, int, int, FrameType]],
) -> FrameType:
    """Derive the dominant frame type from pre-computed regions.

    Same logic as ``classify_frame`` but avoids re-loading the image.
    """
    # First terminal/editor region wins outright — those panel types
    # dominate the frame regardless of how many other regions exist.
    for _, _, _, _, region_type in regions:
        if region_type == FrameType.TERMINAL:
            return FrameType.TERMINAL
        if region_type == FrameType.CODE_EDITOR:
            return FrameType.CODE_EDITOR

    from collections import Counter

    # Otherwise, majority vote over remaining region types.
    tallies = Counter(region_type for *_coords, region_type in regions)
    if not tallies:
        return FrameType.OTHER
    return tallies.most_common(1)[0][0]
|
||||
|
||||
|
||||
def classify_frame(frame_path: str) -> FrameType:
|
||||
"""Classify a video frame by its visual content.
|
||||
|
||||
@@ -1114,6 +1177,8 @@ def _compute_frame_timestamps(
|
||||
duration: float,
|
||||
sample_interval: float = 0.7,
|
||||
min_gap: float = 0.5,
|
||||
start_offset: float = 0.0,
|
||||
end_limit: float | None = None,
|
||||
) -> list[float]:
|
||||
"""Build a deduplicated list of timestamps to extract frames at.
|
||||
|
||||
@@ -1126,10 +1191,13 @@ def _compute_frame_timestamps(
|
||||
duration: Total video duration in seconds.
|
||||
sample_interval: Seconds between interval samples.
|
||||
min_gap: Minimum gap between kept timestamps.
|
||||
start_offset: Start sampling at this time (seconds).
|
||||
end_limit: Stop sampling at this time (seconds). None = full duration.
|
||||
|
||||
Returns:
|
||||
Sorted, deduplicated list of timestamps (seconds).
|
||||
"""
|
||||
effective_end = end_limit if end_limit is not None else duration
|
||||
timestamps: set[float] = set()
|
||||
|
||||
# 1. Scene detection — catches cuts, slide transitions, editor switches
|
||||
@@ -1138,19 +1206,21 @@ def _compute_frame_timestamps(
|
||||
scenes = detect_scenes(video_path)
|
||||
for start, _end in scenes:
|
||||
# Take frame 0.5s after the scene starts (avoids transition blur)
|
||||
timestamps.add(round(start + 0.5, 1))
|
||||
ts = round(start + 0.5, 1)
|
||||
if ts >= start_offset and ts < effective_end:
|
||||
timestamps.add(ts)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(f"Scene detection failed, falling back to interval: {exc}")
|
||||
|
||||
# 2. Regular interval sampling — fills gaps between scene cuts
|
||||
t = 0.5 # start slightly after 0 to avoid black intro frames
|
||||
while t < duration:
|
||||
t = max(0.5, start_offset)
|
||||
while t < effective_end:
|
||||
timestamps.add(round(t, 1))
|
||||
t += sample_interval
|
||||
|
||||
# Always include near the end
|
||||
if duration > 2.0:
|
||||
timestamps.add(round(duration - 1.0, 1))
|
||||
if effective_end > 2.0:
|
||||
timestamps.add(round(effective_end - 1.0, 1))
|
||||
|
||||
# 3. Sort and deduplicate (merge timestamps closer than min_gap)
|
||||
sorted_ts = sorted(timestamps)
|
||||
@@ -1876,6 +1946,8 @@ def extract_visual_data(
|
||||
min_gap: float = 0.5,
|
||||
similarity_threshold: float = 3.0,
|
||||
use_vision_api: bool = False,
|
||||
clip_start: float | None = None,
|
||||
clip_end: float | None = None,
|
||||
) -> tuple[list[KeyFrame], list[CodeBlock], TextGroupTimeline | None]:
|
||||
"""Run continuous visual extraction on a video.
|
||||
|
||||
@@ -1899,6 +1971,8 @@ def extract_visual_data(
|
||||
similarity_threshold: Pixel-diff threshold for duplicate detection (default 3.0).
|
||||
use_vision_api: If True, use Claude Vision API as fallback for low-confidence
|
||||
code frames (requires ANTHROPIC_API_KEY).
|
||||
clip_start: Start of clip range in seconds (None = beginning).
|
||||
clip_end: End of clip range in seconds (None = full duration).
|
||||
|
||||
Returns:
|
||||
Tuple of (keyframes, code_blocks, text_group_timeline).
|
||||
@@ -1937,7 +2011,12 @@ def extract_visual_data(
|
||||
|
||||
# Build candidate timestamps
|
||||
timestamps = _compute_frame_timestamps(
|
||||
video_path, duration, sample_interval=sample_interval, min_gap=min_gap
|
||||
video_path,
|
||||
duration,
|
||||
sample_interval=sample_interval,
|
||||
min_gap=min_gap,
|
||||
start_offset=clip_start or 0.0,
|
||||
end_limit=clip_end,
|
||||
)
|
||||
logger.info(f" {len(timestamps)} candidate timestamps after dedup")
|
||||
|
||||
@@ -1961,17 +2040,21 @@ def extract_visual_data(
|
||||
skipped_similar += 1
|
||||
continue
|
||||
prev_frame = frame.copy()
|
||||
frame_h, frame_w = frame.shape[:2]
|
||||
|
||||
# Save frame
|
||||
idx = len(keyframes)
|
||||
frame_filename = f"frame_{idx:03d}_{ts:.0f}s.jpg"
|
||||
frame_path = os.path.join(frames_dir, frame_filename)
|
||||
cv2.imwrite(frame_path, frame)
|
||||
del frame # Free the numpy array early — saved to disk
|
||||
|
||||
# Classify using region-based panel detection
|
||||
regions = classify_frame_regions(frame_path)
|
||||
code_panels = _get_code_panels(regions)
|
||||
frame_type = classify_frame(frame_path) # dominant type for metadata
|
||||
# Derive frame_type from already-computed regions (avoids loading
|
||||
# the image a second time — classify_frame() would repeat the work).
|
||||
frame_type = _frame_type_from_regions(regions)
|
||||
is_code_frame = frame_type in (FrameType.CODE_EDITOR, FrameType.TERMINAL)
|
||||
|
||||
# Per-panel OCR: each code/terminal panel is OCR'd independently
|
||||
@@ -1982,11 +2065,13 @@ def extract_visual_data(
|
||||
ocr_confidence = 0.0
|
||||
|
||||
if is_code_frame and code_panels and (HAS_EASYOCR or HAS_PYTESSERACT):
|
||||
full_area = frame.shape[0] * frame.shape[1]
|
||||
full_area = frame_h * frame_w
|
||||
|
||||
if len(code_panels) > 1:
|
||||
# Parallel OCR — each panel is independent
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=len(code_panels)) as pool:
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=min(2, len(code_panels))
|
||||
) as pool:
|
||||
futures = {
|
||||
pool.submit(
|
||||
_ocr_single_panel,
|
||||
@@ -2084,8 +2169,8 @@ def extract_visual_data(
|
||||
ocr_text=ocr_text,
|
||||
ocr_regions=ocr_regions,
|
||||
ocr_confidence=ocr_confidence,
|
||||
width=frame.shape[1],
|
||||
height=frame.shape[0],
|
||||
width=frame_w,
|
||||
height=frame_h,
|
||||
sub_sections=sub_sections,
|
||||
)
|
||||
keyframes.append(kf)
|
||||
@@ -2101,6 +2186,10 @@ def extract_visual_data(
|
||||
)
|
||||
)
|
||||
|
||||
# Periodically collect to free PyTorch/numpy memory
|
||||
if idx % 10 == 9:
|
||||
gc.collect()
|
||||
|
||||
cap.release()
|
||||
|
||||
# Finalize text tracking and extract code blocks
|
||||
@@ -2131,7 +2220,12 @@ def extract_visual_data(
|
||||
return keyframes, code_blocks, timeline
|
||||
|
||||
|
||||
def download_video(url: str, output_dir: str) -> str | None:
|
||||
def download_video(
|
||||
url: str,
|
||||
output_dir: str,
|
||||
clip_start: float | None = None,
|
||||
clip_end: float | None = None,
|
||||
) -> str | None:
|
||||
"""Download a video using yt-dlp for visual processing.
|
||||
|
||||
Downloads the best quality up to 1080p. Uses separate video+audio streams
|
||||
@@ -2142,6 +2236,8 @@ def download_video(url: str, output_dir: str) -> str | None:
|
||||
Args:
|
||||
url: Video URL.
|
||||
output_dir: Directory to save the downloaded file.
|
||||
clip_start: Download from this time (seconds). None = beginning.
|
||||
clip_end: Download until this time (seconds). None = full video.
|
||||
|
||||
Returns:
|
||||
Path to downloaded video file, or None on failure.
|
||||
@@ -2156,13 +2252,30 @@ def download_video(url: str, output_dir: str) -> str | None:
|
||||
output_template = os.path.join(output_dir, "video.%(ext)s")
|
||||
|
||||
opts = {
|
||||
"format": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
||||
"format": (
|
||||
"bestvideo[height<=1080][vcodec^=avc1]+bestaudio/best[height<=1080][vcodec^=avc1]/"
|
||||
"bestvideo[height<=1080][vcodec^=h264]+bestaudio/best[height<=1080][vcodec^=h264]/"
|
||||
"bestvideo[height<=1080]+bestaudio/best[height<=1080]"
|
||||
),
|
||||
"merge_output_format": "mp4",
|
||||
"outtmpl": output_template,
|
||||
"quiet": True,
|
||||
"no_warnings": True,
|
||||
}
|
||||
|
||||
# Apply download_ranges for clip support (yt-dlp 2023.01.02+)
|
||||
if clip_start is not None or clip_end is not None:
|
||||
try:
|
||||
from yt_dlp.utils import download_range_func
|
||||
|
||||
ranges = [(clip_start or 0, clip_end or float("inf"))]
|
||||
opts["download_ranges"] = download_range_func(None, ranges)
|
||||
except (ImportError, TypeError):
|
||||
logger.warning(
|
||||
"yt-dlp version does not support download_ranges; "
|
||||
"downloading full video and relying on frame timestamp filtering"
|
||||
)
|
||||
|
||||
logger.info(f"Downloading video for visual extraction...")
|
||||
try:
|
||||
with yt_dlp.YoutubeDL(opts) as ydl:
|
||||
|
||||
@@ -440,6 +440,9 @@ async def scrape_video(
|
||||
visual_min_gap: float | None = None,
|
||||
visual_similarity: float | None = None,
|
||||
vision_ocr: bool = False,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None,
|
||||
setup: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Scrape video content and build Claude skill.
|
||||
@@ -458,10 +461,19 @@ async def scrape_video(
|
||||
visual_min_gap: Minimum seconds between kept frames (default: 2.0)
|
||||
visual_similarity: Similarity threshold to skip duplicate frames 0.0-1.0 (default: 0.95)
|
||||
vision_ocr: Use vision model for OCR on extracted frames
|
||||
start_time: Start time for extraction (seconds, MM:SS, or HH:MM:SS). Single video only.
|
||||
end_time: End time for extraction (seconds, MM:SS, or HH:MM:SS). Single video only.
|
||||
setup: Auto-detect GPU and install visual extraction deps (PyTorch, easyocr, etc.)
|
||||
|
||||
Returns:
|
||||
Video scraping results with file paths.
|
||||
"""
|
||||
if setup:
|
||||
from skill_seekers.cli.video_setup import run_setup
|
||||
|
||||
rc = run_setup(interactive=False)
|
||||
return "Setup completed successfully." if rc == 0 else "Setup failed. Check logs."
|
||||
|
||||
args = {}
|
||||
if url:
|
||||
args["url"] = url
|
||||
@@ -477,6 +489,10 @@ async def scrape_video(
|
||||
args["languages"] = languages
|
||||
if from_json:
|
||||
args["from_json"] = from_json
|
||||
if start_time:
|
||||
args["start_time"] = start_time
|
||||
if end_time:
|
||||
args["end_time"] = end_time
|
||||
if visual:
|
||||
args["visual"] = visual
|
||||
if whisper_model:
|
||||
|
||||
@@ -378,10 +378,21 @@ async def scrape_video_tool(args: dict) -> list[TextContent]:
|
||||
- visual_min_gap (float, optional): Minimum seconds between kept frames (default: 2.0)
|
||||
- visual_similarity (float, optional): Similarity threshold to skip duplicate frames (default: 0.95)
|
||||
- vision_ocr (bool, optional): Use vision model for OCR on frames (default: False)
|
||||
- start_time (str, optional): Start time for extraction (seconds, MM:SS, or HH:MM:SS)
|
||||
- end_time (str, optional): End time for extraction (seconds, MM:SS, or HH:MM:SS)
|
||||
- setup (bool, optional): Auto-detect GPU and install visual extraction deps
|
||||
|
||||
Returns:
|
||||
List[TextContent]: Tool execution results
|
||||
"""
|
||||
# Handle --setup early exit
|
||||
if args.get("setup", False):
|
||||
from skill_seekers.cli.video_setup import run_setup
|
||||
|
||||
rc = run_setup(interactive=False)
|
||||
msg = "Setup completed successfully." if rc == 0 else "Setup failed. Check logs."
|
||||
return [TextContent(type="text", text=msg)]
|
||||
|
||||
url = args.get("url")
|
||||
video_file = args.get("video_file")
|
||||
playlist = args.get("playlist")
|
||||
@@ -395,6 +406,8 @@ async def scrape_video_tool(args: dict) -> list[TextContent]:
|
||||
visual_min_gap = args.get("visual_min_gap")
|
||||
visual_similarity = args.get("visual_similarity")
|
||||
vision_ocr = args.get("vision_ocr", False)
|
||||
start_time = args.get("start_time")
|
||||
end_time = args.get("end_time")
|
||||
|
||||
# Build command
|
||||
cmd = [sys.executable, str(CLI_DIR / "video_scraper.py")]
|
||||
@@ -440,6 +453,10 @@ async def scrape_video_tool(args: dict) -> list[TextContent]:
|
||||
cmd.extend(["--visual-similarity", str(visual_similarity)])
|
||||
if vision_ocr:
|
||||
cmd.append("--vision-ocr")
|
||||
if start_time:
|
||||
cmd.extend(["--start-time", str(start_time)])
|
||||
if end_time:
|
||||
cmd.extend(["--end-time", str(end_time)])
|
||||
|
||||
# Run video_scraper.py with streaming
|
||||
timeout = 600 # 10 minutes for video extraction
|
||||
|
||||
Reference in New Issue
Block a user