Release v1.9.0: Add video-comparer skill and enhance transcript-fixer
## New Skill: video-comparer v1.0.0 - Compare original and compressed videos with interactive HTML reports - Calculate quality metrics (PSNR, SSIM) for compression analysis - Generate frame-by-frame visual comparisons (slider, side-by-side, grid) - Extract video metadata (codec, resolution, bitrate, duration) - Multi-platform FFmpeg support with security features ## transcript-fixer Enhancements - Add async AI processor for parallel processing - Add connection pool management for database operations - Add concurrency manager and rate limiter - Add audit log retention and database migrations - Add health check and metrics monitoring - Add comprehensive test suite (8 new test files) - Enhance security with domain and path validators ## Marketplace Updates - Update marketplace version from 1.8.0 to 1.9.0 - Update skills count from 15 to 16 - Update documentation (README.md, CLAUDE.md, CHANGELOG.md) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
758
transcript-fixer/scripts/tests/test_audit_log_retention.py
Normal file
758
transcript-fixer/scripts/tests/test_audit_log_retention.py
Normal file
@@ -0,0 +1,758 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive tests for Audit Log Retention Management (P1-11)
|
||||
|
||||
Test Coverage:
|
||||
1. Retention policy enforcement
|
||||
2. Cleanup strategies (DELETE, ARCHIVE, ANONYMIZE)
|
||||
3. Critical action extended retention
|
||||
4. Compliance reporting
|
||||
5. Archive creation and restoration
|
||||
6. Dry-run mode
|
||||
7. Transaction safety
|
||||
8. Error handling
|
||||
|
||||
Author: Chief Engineer (ISTJ, 20 years experience)
|
||||
Date: 2025-10-29
|
||||
"""
|
||||
|
||||
import gzip
|
||||
import json
|
||||
import pytest
|
||||
import sqlite3
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from utils.audit_log_retention import (
|
||||
AuditLogRetentionManager,
|
||||
RetentionPolicy,
|
||||
RetentionPeriod,
|
||||
CleanupStrategy,
|
||||
CleanupResult,
|
||||
ComplianceReport,
|
||||
CRITICAL_ACTIONS,
|
||||
get_retention_manager,
|
||||
reset_retention_manager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def test_db(tmp_path):
    """Create a temporary SQLite database with the retention schema.

    Yields the database path. The connection used for schema creation is
    closed in a ``finally`` block so it cannot leak if a CREATE fails, and
    the database file is removed in a ``finally`` block so cleanup runs
    even when the consuming test raises.
    """
    db_path = tmp_path / "test_retention.db"

    conn = sqlite3.connect(str(db_path))
    try:
        cursor = conn.cursor()

        # Audit trail rows that the retention policies operate on.
        cursor.execute("""
            CREATE TABLE audit_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                action TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                entity_id INTEGER,
                user TEXT,
                details TEXT,
                success INTEGER DEFAULT 1,
                error_message TEXT
            )
        """)

        # Per-entity-type retention configuration.
        cursor.execute("""
            CREATE TABLE retention_policies (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                entity_type TEXT UNIQUE NOT NULL,
                retention_days INTEGER NOT NULL,
                is_active INTEGER DEFAULT 1,
                description TEXT
            )
        """)

        # Record of past cleanup runs (written by the manager).
        cursor.execute("""
            CREATE TABLE cleanup_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                entity_type TEXT NOT NULL,
                records_deleted INTEGER DEFAULT 0,
                execution_time_ms INTEGER DEFAULT 0,
                success INTEGER DEFAULT 1,
                error_message TEXT,
                timestamp TEXT DEFAULT CURRENT_TIMESTAMP
            )
        """)

        conn.commit()
    finally:
        conn.close()

    try:
        yield db_path
    finally:
        # Guaranteed teardown even if the test body raised.
        if db_path.exists():
            db_path.unlink()
|
||||
|
||||
|
||||
@pytest.fixture
def retention_manager(test_db, tmp_path):
    """Yield an AuditLogRetentionManager bound to the temporary database.

    The global singleton is reset in a ``finally`` block so teardown runs
    even when the consuming test fails — pytest raises the failure into the
    generator at the ``yield`` point, which would otherwise skip the reset
    and leak singleton state into later tests.
    """
    archive_dir = tmp_path / "archives"
    manager = AuditLogRetentionManager(test_db, archive_dir)
    try:
        yield manager
    finally:
        reset_retention_manager()
|
||||
|
||||
|
||||
def insert_audit_log(
    db_path: Path,
    action: str,
    entity_type: str,
    days_ago: int,
    entity_id: int = 1,
    user: str = "test_user"
) -> int:
    """Insert one ``audit_log`` row backdated by *days_ago* days.

    Args:
        db_path: Path to the SQLite database containing the audit_log table.
        action: Value for the ``action`` column.
        entity_type: Value for the ``entity_type`` column.
        days_ago: How many days in the past to timestamp the row.
        entity_id: Value for the ``entity_id`` column (default 1).
        user: Value for the ``user`` column (default ``"test_user"``).

    Returns:
        The id of the newly inserted row.
    """
    timestamp = (datetime.now() - timedelta(days=days_ago)).isoformat()

    conn = sqlite3.connect(str(db_path))
    try:
        # Connection.execute avoids an explicit cursor object; the try/finally
        # guarantees the connection is closed even if the INSERT raises.
        cursor = conn.execute("""
            INSERT INTO audit_log (timestamp, action, entity_type, entity_id, user, details, success)
            VALUES (?, ?, ?, ?, ?, ?, 1)
        """, (timestamp, action, entity_type, entity_id, user, json.dumps({"key": "value"})))

        log_id = cursor.lastrowid
        conn.commit()
    finally:
        conn.close()

    return log_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 1: Retention Policy Enforcement
|
||||
# =============================================================================
|
||||
|
||||
def test_default_retention_policies(retention_manager):
    """Built-in policies exist and the correction policy matches the spec."""
    policies = retention_manager.load_retention_policies()

    # All four default entity types must be present.
    for expected in ('correction', 'suggestion', 'system', 'migration'):
        assert expected in policies

    # The correction policy: yearly retention, archived, SOX-length
    # retention for critical actions.
    correction = policies['correction']
    assert correction.retention_days == RetentionPeriod.ANNUAL.value
    assert correction.strategy == CleanupStrategy.ARCHIVE
    assert correction.critical_action_retention_days == RetentionPeriod.COMPLIANCE_SOX.value
|
||||
|
||||
|
||||
def test_custom_retention_policy_from_database(test_db, retention_manager):
    """A policy row added to retention_policies is picked up on load."""
    # Seed a custom policy row directly.
    conn = sqlite3.connect(str(test_db))
    conn.execute("""
        INSERT INTO retention_policies (entity_type, retention_days, is_active, description)
        VALUES ('custom_entity', 60, 1, 'Custom test policy')
    """)
    conn.commit()
    conn.close()

    policies = retention_manager.load_retention_policies()

    # The custom row must surface as a loaded policy.
    assert 'custom_entity' in policies
    custom = policies['custom_entity']
    assert custom.retention_days == 60
    assert custom.is_active is True
|
||||
|
||||
|
||||
def test_retention_policy_validation():
    """RetentionPolicy accepts sane values and rejects inconsistent ones."""
    # A well-formed policy constructs cleanly.
    valid = RetentionPolicy(
        entity_type='test',
        retention_days=30,
        strategy=CleanupStrategy.ARCHIVE
    )
    assert valid.retention_days == 30

    # Negative day counts (other than the -1 "keep forever" sentinel)
    # must be rejected.
    with pytest.raises(ValueError, match="retention_days must be -1"):
        RetentionPolicy(
            entity_type='test',
            retention_days=-5,
            strategy=CleanupStrategy.DELETE
        )

    # Critical-action retention may never undercut the base retention.
    with pytest.raises(ValueError, match="critical_action_retention_days must be"):
        RetentionPolicy(
            entity_type='test',
            retention_days=365,
            critical_action_retention_days=30,  # shorter than retention_days
            strategy=CleanupStrategy.ARCHIVE
        )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 2: Cleanup Strategies
|
||||
# =============================================================================
|
||||
|
||||
def test_cleanup_strategy_delete(test_db, retention_manager):
    """The DELETE strategy permanently removes expired rows."""
    # Seed five rows that are well past the retention window.
    for _ in range(5):
        insert_audit_log(test_db, 'test_action', 'correction', days_ago=400)

    policy = retention_manager.default_policies['correction']
    policy.strategy = CleanupStrategy.DELETE
    policy.retention_days = 365

    results = retention_manager.cleanup_expired_logs(entity_type='correction')

    assert len(results) == 1
    outcome = results[0]
    assert outcome.entity_type == 'correction'
    assert outcome.records_deleted == 5
    assert outcome.records_archived == 0
    assert outcome.success is True

    # Nothing should remain in the table for this entity type.
    conn = sqlite3.connect(str(test_db))
    remaining = conn.execute(
        "SELECT COUNT(*) FROM audit_log WHERE entity_type = 'correction'"
    ).fetchone()[0]
    conn.close()

    assert remaining == 0
|
||||
|
||||
|
||||
def test_cleanup_strategy_archive(test_db, retention_manager):
    """The ARCHIVE strategy writes rows to a gzip archive, then deletes them."""
    # Seed five expired rows and remember their ids.
    log_ids = [
        insert_audit_log(test_db, 'test_action', 'suggestion', days_ago=100)
        for _ in range(5)
    ]

    policy = retention_manager.default_policies['suggestion']
    policy.strategy = CleanupStrategy.ARCHIVE
    policy.retention_days = 90

    results = retention_manager.cleanup_expired_logs(entity_type='suggestion')

    assert len(results) == 1
    outcome = results[0]
    assert outcome.entity_type == 'suggestion'
    assert outcome.records_deleted == 5
    assert outcome.records_archived == 5
    assert outcome.success is True

    # Exactly one compressed archive should have been produced.
    archive_files = list(retention_manager.archive_dir.glob("audit_log_suggestion_*.json.gz"))
    assert len(archive_files) == 1

    # Its contents must be the five archived rows.
    with gzip.open(archive_files[0], 'rt', encoding='utf-8') as f:
        archived_logs = json.load(f)

    assert len(archived_logs) == 5
    for log in archived_logs:
        assert log['id'] in log_ids
|
||||
|
||||
|
||||
def test_cleanup_strategy_anonymize(test_db, retention_manager):
    """The ANONYMIZE strategy strips PII but keeps the rows."""
    # Seed expired rows carrying identifiable user addresses.
    for i in range(3):
        insert_audit_log(
            test_db,
            'test_action',
            'correction',
            days_ago=400,
            user=f'user_{i}@example.com'
        )

    policy = retention_manager.default_policies['correction']
    policy.strategy = CleanupStrategy.ANONYMIZE
    policy.retention_days = 365

    results = retention_manager.cleanup_expired_logs(entity_type='correction')

    assert len(results) == 1
    outcome = results[0]
    assert outcome.entity_type == 'correction'
    assert outcome.records_anonymized == 3
    assert outcome.records_deleted == 0
    assert outcome.success is True

    # Every surviving row must carry the anonymized placeholder.
    conn = sqlite3.connect(str(test_db))
    rows = conn.execute(
        "SELECT user FROM audit_log WHERE entity_type = 'correction'"
    ).fetchall()
    conn.close()

    for (user,) in rows:
        assert user == 'ANONYMIZED'
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 3: Critical Action Extended Retention
|
||||
# =============================================================================
|
||||
|
||||
def test_critical_action_extended_retention(test_db, retention_manager):
    """Critical actions outlive the regular retention window."""
    # Two equally old rows: one ordinary, one critical.
    insert_audit_log(test_db, 'regular_action', 'correction', days_ago=400)
    insert_audit_log(test_db, 'delete_correction', 'correction', days_ago=400)  # critical

    # One-year base retention, seven-year (SOX) retention for critical actions.
    policy = retention_manager.default_policies['correction']
    policy.retention_days = 365
    policy.critical_action_retention_days = 2555
    policy.strategy = CleanupStrategy.DELETE

    results = retention_manager.cleanup_expired_logs(entity_type='correction')

    # Only the ordinary row falls inside the deletable window.
    assert results[0].records_deleted == 1

    conn = sqlite3.connect(str(test_db))
    rows = conn.execute(
        "SELECT action FROM audit_log WHERE entity_type = 'correction'"
    ).fetchall()
    conn.close()
    surviving = [action for (action,) in rows]

    assert 'delete_correction' in surviving
    assert 'regular_action' not in surviving
|
||||
|
||||
|
||||
def test_critical_actions_set_completeness():
    """Every known critical action must be registered in CRITICAL_ACTIONS."""
    required = (
        'delete_correction',
        'update_correction',
        'approve_learned_suggestion',
        'reject_learned_suggestion',
        'system_config_change',
        'migration_applied',
        'security_event',
    )

    for action in required:
        assert action in CRITICAL_ACTIONS
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 4: Compliance Reporting
|
||||
# =============================================================================
|
||||
|
||||
def test_compliance_report_generation(test_db, retention_manager):
    """A report summarizes counts, date range, per-type breakdown and size."""
    # Seed one row per entity type at different ages.
    seed = (
        ('action1', 'correction', 10),
        ('action2', 'suggestion', 100),
        ('action3', 'system', 200),
    )
    for action, entity_type, age in seed:
        insert_audit_log(test_db, action, entity_type, days_ago=age)

    report = retention_manager.generate_compliance_report()

    assert isinstance(report, ComplianceReport)
    assert report.total_audit_logs == 3
    assert report.oldest_log_date is not None
    assert report.newest_log_date is not None
    assert 'correction' in report.logs_by_entity_type
    assert 'suggestion' in report.logs_by_entity_type
    assert report.storage_size_mb > 0
|
||||
|
||||
|
||||
def test_compliance_report_detects_violations(test_db, retention_manager):
    """Rows older than their policy window are reported as violations."""
    # A 100-day-old row against a 30-day policy is out of compliance.
    insert_audit_log(test_db, 'old_action', 'suggestion', days_ago=100)
    retention_manager.default_policies['suggestion'].retention_days = 30

    report = retention_manager.generate_compliance_report()

    assert report.is_compliant is False
    assert len(report.retention_violations) > 0
    assert 'suggestion' in report.retention_violations[0]
|
||||
|
||||
|
||||
def test_compliance_report_no_violations(test_db, retention_manager):
    """Fresh data inside every policy window yields a compliant report."""
    insert_audit_log(test_db, 'recent_action', 'correction', days_ago=10)

    report = retention_manager.generate_compliance_report()

    assert report.is_compliant is True
    assert len(report.retention_violations) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 5: Archive Operations
|
||||
# =============================================================================
|
||||
|
||||
def test_archive_creation_and_compression(test_db, retention_manager):
    """Archiving produces one valid gzip file holding all expired rows."""
    # Ten expired rows with distinct action names.
    for i in range(10):
        insert_audit_log(test_db, f'action_{i}', 'correction', days_ago=400)

    policy = retention_manager.default_policies['correction']
    policy.retention_days = 365
    policy.strategy = CleanupStrategy.ARCHIVE

    retention_manager.cleanup_expired_logs(entity_type='correction')

    # Exactly one archive file appears for this entity type.
    archive_files = list(retention_manager.archive_dir.glob("audit_log_correction_*.json.gz"))
    assert len(archive_files) == 1

    # It must decompress into JSON rows carrying at least id and action.
    with gzip.open(archive_files[0], 'rt', encoding='utf-8') as f:
        logs = json.load(f)

    assert len(logs) == 10
    for log in logs:
        assert 'id' in log
        assert 'action' in log
|
||||
|
||||
|
||||
def test_restore_from_archive(test_db, retention_manager):
    """Restoring an archive puts the exact original rows back."""
    # Seed and remember five expired rows.
    original_ids = [
        insert_audit_log(test_db, f'action_{i}', 'correction', days_ago=400)
        for i in range(5)
    ]

    # Archive (and thereby delete) them.
    policy = retention_manager.default_policies['correction']
    policy.retention_days = 365
    policy.strategy = CleanupStrategy.ARCHIVE
    retention_manager.cleanup_expired_logs(entity_type='correction')

    # Sanity check: the table is now empty for this entity type.
    conn = sqlite3.connect(str(test_db))
    count = conn.execute(
        "SELECT COUNT(*) FROM audit_log WHERE entity_type = 'correction'"
    ).fetchone()[0]
    conn.close()
    assert count == 0

    # Restore from the archive file that was just written.
    archive_files = list(retention_manager.archive_dir.glob("audit_log_correction_*.json.gz"))
    restored_count = retention_manager.restore_from_archive(archive_files[0])
    assert restored_count == 5

    # The restored ids must match the originals exactly.
    conn = sqlite3.connect(str(test_db))
    rows = conn.execute(
        "SELECT id FROM audit_log WHERE entity_type = 'correction' ORDER BY id"
    ).fetchall()
    conn.close()
    restored_ids = [row_id for (row_id,) in rows]

    assert sorted(restored_ids) == sorted(original_ids)
|
||||
|
||||
|
||||
def test_restore_verify_only_mode(test_db, retention_manager):
    """verify_only counts archived rows without writing them back."""
    # Build an archive of three expired rows.
    for i in range(3):
        insert_audit_log(test_db, f'action_{i}', 'suggestion', days_ago=100)

    policy = retention_manager.default_policies['suggestion']
    policy.retention_days = 90
    policy.strategy = CleanupStrategy.ARCHIVE
    retention_manager.cleanup_expired_logs(entity_type='suggestion')

    # Verification reports the row count without touching the database.
    archive_files = list(retention_manager.archive_dir.glob("audit_log_suggestion_*.json.gz"))
    count = retention_manager.restore_from_archive(archive_files[0], verify_only=True)
    assert count == 3

    # The table stays empty: nothing was actually restored.
    conn = sqlite3.connect(str(test_db))
    db_count = conn.execute(
        "SELECT COUNT(*) FROM audit_log WHERE entity_type = 'suggestion'"
    ).fetchone()[0]
    conn.close()

    assert db_count == 0
|
||||
|
||||
|
||||
def test_restore_skips_duplicates(test_db, retention_manager):
    """Restoring the same archive twice is a no-op the second time."""
    # Archive three expired rows.
    for i in range(3):
        insert_audit_log(test_db, f'action_{i}', 'correction', days_ago=400)

    policy = retention_manager.default_policies['correction']
    policy.retention_days = 365
    policy.strategy = CleanupStrategy.ARCHIVE
    retention_manager.cleanup_expired_logs(entity_type='correction')

    archive_files = list(retention_manager.archive_dir.glob("audit_log_correction_*.json.gz"))

    # First restore brings everything back; the second finds only duplicates.
    assert retention_manager.restore_from_archive(archive_files[0]) == 3
    assert retention_manager.restore_from_archive(archive_files[0]) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 6: Dry-Run Mode
|
||||
# =============================================================================
|
||||
|
||||
def test_dry_run_mode_no_changes(test_db, retention_manager):
    """Dry-run reports what would be deleted without deleting anything."""
    for _ in range(5):
        insert_audit_log(test_db, 'action', 'correction', days_ago=400)

    policy = retention_manager.default_policies['correction']
    policy.retention_days = 365
    policy.strategy = CleanupStrategy.DELETE

    results = retention_manager.cleanup_expired_logs(entity_type='correction', dry_run=True)

    assert len(results) == 1
    outcome = results[0]
    assert outcome.records_scanned == 5
    assert outcome.records_deleted == 5  # reported as "would delete"
    assert outcome.success is True

    # All five rows must still be present.
    conn = sqlite3.connect(str(test_db))
    remaining = conn.execute(
        "SELECT COUNT(*) FROM audit_log WHERE entity_type = 'correction'"
    ).fetchone()[0]
    conn.close()

    assert remaining == 5
|
||||
|
||||
|
||||
def test_dry_run_mode_archive_strategy(test_db, retention_manager):
    """Dry-run with ARCHIVE reports counts but writes no archive files."""
    for _ in range(3):
        insert_audit_log(test_db, 'action', 'suggestion', days_ago=100)

    policy = retention_manager.default_policies['suggestion']
    policy.retention_days = 90
    policy.strategy = CleanupStrategy.ARCHIVE

    results = retention_manager.cleanup_expired_logs(entity_type='suggestion', dry_run=True)

    # Reported as "would archive" — no side effects yet.
    assert results[0].records_archived == 3

    archive_files = list(retention_manager.archive_dir.glob("audit_log_suggestion_*.json.gz"))
    assert len(archive_files) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 7: Transaction Safety
|
||||
# =============================================================================
|
||||
|
||||
def test_transaction_rollback_on_archive_failure(test_db, retention_manager, monkeypatch):
    """A failing archive step must leave every audit row untouched."""
    for _ in range(3):
        insert_audit_log(test_db, 'action', 'correction', days_ago=400)

    policy = retention_manager.default_policies['correction']
    policy.retention_days = 365
    policy.strategy = CleanupStrategy.ARCHIVE

    # Force the archive step to blow up mid-cleanup.
    def failing_archive(*args, **kwargs):
        raise IOError("Archive write failed")

    monkeypatch.setattr(retention_manager, '_archive_logs', failing_archive)

    results = retention_manager.cleanup_expired_logs(entity_type='correction')

    assert len(results) == 1
    outcome = results[0]
    assert outcome.success is False
    assert len(outcome.errors) > 0

    # The transaction rolled back: all three rows survive.
    conn = sqlite3.connect(str(test_db))
    remaining = conn.execute(
        "SELECT COUNT(*) FROM audit_log WHERE entity_type = 'correction'"
    ).fetchone()[0]
    conn.close()

    assert remaining == 3
|
||||
|
||||
|
||||
def test_cleanup_history_recorded(test_db, retention_manager):
    """Each cleanup run leaves an entry in cleanup_history."""
    for _ in range(5):
        insert_audit_log(test_db, 'action', 'correction', days_ago=400)

    policy = retention_manager.default_policies['correction']
    policy.retention_days = 365
    policy.strategy = CleanupStrategy.DELETE
    retention_manager.cleanup_expired_logs(entity_type='correction')

    # The history table must record the run with its deletion count.
    conn = sqlite3.connect(str(test_db))
    row = conn.execute("""
        SELECT entity_type, records_deleted, success
        FROM cleanup_history
        WHERE entity_type = 'correction'
    """).fetchone()
    conn.close()

    # entity_type, records_deleted, success flag
    assert row == ('correction', 5, 1)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 8: Error Handling
|
||||
# =============================================================================
|
||||
|
||||
def test_handle_missing_archive_file(retention_manager):
    """Restoring from a nonexistent archive raises FileNotFoundError."""
    missing = Path("/nonexistent/archive.json.gz")

    with pytest.raises(FileNotFoundError, match="Archive file not found"):
        retention_manager.restore_from_archive(missing)
|
||||
|
||||
|
||||
def test_handle_invalid_entity_type(retention_manager):
    """Cleanup for an unknown entity type yields no results."""
    results = retention_manager.cleanup_expired_logs(entity_type='nonexistent_type')

    # No policy exists for this type, so there is nothing to do.
    assert not results
|
||||
|
||||
|
||||
def test_permanent_retention_skipped(test_db, retention_manager):
    """Entity types with permanent retention are never cleaned up."""
    # Migration logs older than 8 years — still protected by default policy.
    for _ in range(3):
        insert_audit_log(test_db, 'migration_applied', 'migration', days_ago=3000)

    results = retention_manager.cleanup_expired_logs(entity_type='migration')

    # Permanent retention means cleanup is skipped outright.
    assert len(results) == 0

    conn = sqlite3.connect(str(test_db))
    remaining = conn.execute(
        "SELECT COUNT(*) FROM audit_log WHERE entity_type = 'migration'"
    ).fetchone()[0]
    conn.close()

    assert remaining == 3
|
||||
|
||||
|
||||
def test_anonymize_handles_invalid_json(test_db, retention_manager):
    """Anonymization tolerates a non-JSON details column."""
    # Insert a row whose details field is not parseable JSON.
    conn = sqlite3.connect(str(test_db))
    timestamp = (datetime.now() - timedelta(days=400)).isoformat()
    conn.execute("""
        INSERT INTO audit_log (timestamp, action, entity_type, user, details)
        VALUES (?, 'test', 'correction', 'user@example.com', 'NOT_JSON')
    """, (timestamp,))
    conn.commit()
    conn.close()

    policy = retention_manager.default_policies['correction']
    policy.retention_days = 365
    policy.strategy = CleanupStrategy.ANONYMIZE

    results = retention_manager.cleanup_expired_logs(entity_type='correction')

    # The bad JSON must not raise; the row is still anonymized.
    assert results[0].success is True
    assert results[0].records_anonymized == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Group 9: Global Instance Management
|
||||
# =============================================================================
|
||||
|
||||
def test_global_retention_manager_singleton(test_db, tmp_path):
    """get_retention_manager hands back the same instance until reset."""
    reset_retention_manager()

    first = get_retention_manager(test_db, tmp_path / "archives")

    # A second call without arguments must return the cached instance.
    assert get_retention_manager() is first

    reset_retention_manager()
|
||||
|
||||
|
||||
def test_global_retention_manager_reset(test_db, tmp_path):
    """reset_retention_manager discards the cached singleton."""
    reset_retention_manager()
    archive_dir = tmp_path / "archives"

    before = get_retention_manager(test_db, archive_dir)
    reset_retention_manager()
    after = get_retention_manager(test_db, archive_dir)

    # A fresh instance must be created after the reset.
    assert before is not after

    reset_retention_manager()
|
||||
|
||||
|
||||
# Allow running this test module directly (outside a pytest session).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
|
||||
343
transcript-fixer/scripts/tests/test_connection_pool.py
Normal file
343
transcript-fixer/scripts/tests/test_connection_pool.py
Normal file
@@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Suite for Thread-Safe Connection Pool
|
||||
|
||||
CRITICAL FIX VERIFICATION: Tests for Critical-1
|
||||
Purpose: Verify thread-safe connection pool prevents data corruption
|
||||
|
||||
Test Coverage:
|
||||
1. Basic pool operations
|
||||
2. Concurrent access (race conditions)
|
||||
3. Pool exhaustion handling
|
||||
4. Connection cleanup
|
||||
5. Statistics tracking
|
||||
|
||||
Author: Chief Engineer
|
||||
Priority: P0 - Critical
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from core.connection_pool import (
|
||||
ConnectionPool,
|
||||
PoolExhaustedError,
|
||||
MAX_CONNECTIONS
|
||||
)
|
||||
|
||||
|
||||
class TestConnectionPoolBasics:
    """Constructor-level validation of ConnectionPool."""

    def test_pool_initialization(self, tmp_path):
        """A pool built with valid parameters records its configuration."""
        db_path = tmp_path / "test.db"
        pool = ConnectionPool(db_path, max_connections=3)

        assert pool.max_connections == 3
        assert pool.db_path == db_path

        pool.close_all()

    def test_pool_invalid_max_connections(self, tmp_path):
        """Zero or negative max_connections is rejected."""
        db_path = tmp_path / "test.db"

        for bad in (0, -1):
            with pytest.raises(ValueError, match="max_connections must be >= 1"):
                ConnectionPool(db_path, max_connections=bad)

    def test_pool_invalid_timeout(self, tmp_path):
        """Negative timeouts are rejected with a descriptive error."""
        db_path = tmp_path / "test.db"

        with pytest.raises(ValueError, match="connection_timeout"):
            ConnectionPool(db_path, connection_timeout=-1)

        with pytest.raises(ValueError, match="pool_timeout"):
            ConnectionPool(db_path, pool_timeout=-1)

    def test_pool_nonexistent_directory(self):
        """A db path under a missing directory raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="doesn't exist"):
            ConnectionPool(Path("/nonexistent/directory/test.db"))
|
||||
|
||||
|
||||
class TestConnectionOperations:
    """Acquire/release semantics and per-connection PRAGMA configuration."""

    def test_get_connection_basic(self, tmp_path):
        """get_connection() yields a live, usable sqlite3.Connection."""
        pool = ConnectionPool(tmp_path / "test.db", max_connections=2)

        with pool.get_connection() as conn:
            assert isinstance(conn, sqlite3.Connection)
            # A trivial query proves the handle actually works.
            assert conn.execute("SELECT 1").fetchone()[0] == 1

        pool.close_all()

    def test_connection_returned_to_pool(self, tmp_path):
        """A released connection is available for the next acquisition."""
        pool = ConnectionPool(tmp_path / "test.db", max_connections=1)

        # With a pool of one, back-to-back usage only succeeds if the first
        # context manager returned the connection on exit.
        with pool.get_connection() as conn:
            conn.execute("SELECT 1")

        with pool.get_connection() as conn:
            conn.execute("SELECT 2")

        pool.close_all()

    def test_wal_mode_enabled(self, tmp_path):
        """Connections run in WAL journal mode for reader/writer concurrency."""
        pool = ConnectionPool(tmp_path / "test.db")

        with pool.get_connection() as conn:
            journal_mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
            assert journal_mode.upper() == "WAL"

        pool.close_all()

    def test_foreign_keys_enabled(self, tmp_path):
        """Foreign-key enforcement is switched on for every connection."""
        pool = ConnectionPool(tmp_path / "test.db")

        with pool.get_connection() as conn:
            assert conn.execute("PRAGMA foreign_keys").fetchone()[0] == 1

        pool.close_all()
|
||||
|
||||
|
||||
class TestConcurrency:
    """
    CRITICAL: concurrent-access regression tests.

    The previous implementation shared connections across threads via
    check_same_thread=False, which allowed race conditions; these tests
    verify the pooled design is safe under parallel load.
    """

    def test_concurrent_reads(self, tmp_path):
        """Ten threads reading at once all observe the same row count."""
        pool = ConnectionPool(tmp_path / "test.db", max_connections=5)

        # Seed a small table to read from.
        with pool.get_connection() as conn:
            conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)")
            conn.execute("INSERT INTO test (value) VALUES ('test1'), ('test2'), ('test3')")
            conn.commit()

        results = []
        errors = []

        def read_data(thread_id):
            try:
                with pool.get_connection() as conn:
                    row_count = conn.execute("SELECT COUNT(*) FROM test").fetchone()[0]
                    results.append((thread_id, row_count))
            except Exception as exc:
                errors.append((thread_id, str(exc)))

        with ThreadPoolExecutor(max_workers=10) as executor:
            pending = [executor.submit(read_data, i) for i in range(10)]
            for fut in as_completed(pending):
                fut.result()  # propagate unexpected failures

        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert len(results) == 10
        assert all(count == 3 for _, count in results), "Race condition detected!"

        pool.close_all()

    def test_concurrent_writes_no_corruption(self, tmp_path):
        """
        CRITICAL TEST: 100 concurrent read-modify-write cycles must not
        corrupt the database (this scenario failed with check_same_thread=False).
        """
        pool = ConnectionPool(tmp_path / "test.db", max_connections=5)

        with pool.get_connection() as conn:
            conn.execute("CREATE TABLE counter (id INTEGER PRIMARY KEY, value INTEGER)")
            conn.execute("INSERT INTO counter (id, value) VALUES (1, 0)")
            conn.commit()

        errors = []

        def increment_counter(thread_id):
            try:
                with pool.get_connection() as conn:
                    current = conn.execute(
                        "SELECT value FROM counter WHERE id = 1"
                    ).fetchone()[0]
                    conn.execute(
                        "UPDATE counter SET value = ? WHERE id = 1", (current + 1,)
                    )
                    conn.commit()
            except Exception as exc:
                errors.append((thread_id, str(exc)))

        with ThreadPoolExecutor(max_workers=10) as executor:
            pending = [executor.submit(increment_counter, i) for i in range(100)]
            for fut in as_completed(pending):
                fut.result()

        with pool.get_connection() as conn:
            final_value = conn.execute(
                "SELECT value FROM counter WHERE id = 1"
            ).fetchone()[0]

        # Lost updates are possible (the read-modify-write itself races), so
        # the final count may be below 100.  What must hold: no errors, no
        # corruption, and a sane positive value.
        assert len(errors) == 0, f"Errors: {errors}"
        assert final_value > 0, "Counter should have increased"
        assert final_value <= 100, "Counter shouldn't exceed number of increments"

        pool.close_all()
|
||||
|
||||
|
||||
class TestPoolExhaustion:
    """Behaviour when every pooled connection is in use."""

    def test_pool_exhaustion_timeout(self, tmp_path):
        """With all connections held, a further request raises PoolExhaustedError."""
        pool = ConnectionPool(tmp_path / "test.db", max_connections=2, pool_timeout=0.5)

        # Manually enter both context managers so the connections stay busy.
        held = [pool.get_connection() for _ in range(2)]
        for ctx in held:
            ctx.__enter__()

        # The third request cannot be satisfied and must time out.
        with pytest.raises(PoolExhaustedError, match="No connection available"):
            with pool.get_connection() as conn3:
                pass

        # Hand the connections back before shutting down.
        for ctx in held:
            ctx.__exit__(None, None, None)

        pool.close_all()

    def test_pool_recovery_after_exhaustion(self, tmp_path):
        """Once connections are released the pool serves requests again."""
        pool = ConnectionPool(tmp_path / "test.db", max_connections=1, pool_timeout=0.5)

        with pool.get_connection() as conn:
            conn.execute("SELECT 1")

        # The single slot must be free again here.
        with pool.get_connection() as conn:
            conn.execute("SELECT 2")

        pool.close_all()
|
||||
|
||||
|
||||
class TestStatistics:
    """Pool statistics bookkeeping."""

    def test_statistics_initialization(self, tmp_path):
        """A fresh pool reports its size and zeroed counters."""
        pool = ConnectionPool(tmp_path / "test.db", max_connections=3)

        stats = pool.get_statistics()

        assert stats.total_connections == 3
        assert stats.total_acquired == 0
        assert stats.total_released == 0
        assert stats.total_timeouts == 0

        pool.close_all()

    def test_statistics_tracking(self, tmp_path):
        """Acquire/release counters advance once per checkout."""
        pool = ConnectionPool(tmp_path / "test.db", max_connections=2)

        with pool.get_connection() as conn:
            conn.execute("SELECT 1")

        with pool.get_connection() as conn:
            conn.execute("SELECT 2")

        stats = pool.get_statistics()

        assert stats.total_acquired == 2
        assert stats.total_released == 2

        pool.close_all()
|
||||
|
||||
|
||||
class TestCleanup:
    """Resource cleanup via close_all() and the context-manager protocol."""

    def test_close_all_connections(self, tmp_path):
        """close_all() runs cleanly after the pool has handed out connections."""
        pool = ConnectionPool(tmp_path / "test.db", max_connections=3)

        # Touch the pool so at least one connection has been created.
        with pool.get_connection() as conn:
            conn.execute("SELECT 1")

        pool.close_all()
        # NOTE(review): behaviour of a closed pool is unspecified here, so
        # this test only asserts that shutdown itself does not raise.

    def test_context_manager_cleanup(self, tmp_path):
        """Using the pool as a context manager shuts it down on exit."""
        with ConnectionPool(tmp_path / "test.db", max_connections=2) as pool:
            with pool.get_connection() as conn:
                conn.execute("SELECT 1")
        # Leaving the with-block is expected to have closed the pool.
|
||||
|
||||
|
||||
# Direct invocation: python test_connection_pool.py (equivalent to pytest -v)
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
|
||||
302
transcript-fixer/scripts/tests/test_domain_validator.py
Normal file
302
transcript-fixer/scripts/tests/test_domain_validator.py
Normal file
@@ -0,0 +1,302 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Suite for Domain Validator
|
||||
|
||||
CRITICAL FIX VERIFICATION: Tests for Critical-3
|
||||
Purpose: Verify SQL injection prevention and input validation
|
||||
|
||||
Test Coverage:
|
||||
1. Domain whitelist validation
|
||||
2. Source whitelist validation
|
||||
3. Text sanitization
|
||||
4. Confidence validation
|
||||
5. SQL injection attack prevention
|
||||
6. DoS prevention (length limits)
|
||||
|
||||
Author: Chief Engineer
|
||||
Priority: P0 - Critical
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from utils.domain_validator import (
|
||||
validate_domain,
|
||||
validate_source,
|
||||
sanitize_text_field,
|
||||
validate_correction_inputs,
|
||||
validate_confidence,
|
||||
is_safe_sql_identifier,
|
||||
ValidationError,
|
||||
VALID_DOMAINS,
|
||||
VALID_SOURCES,
|
||||
MAX_FROM_TEXT_LENGTH,
|
||||
MAX_TO_TEXT_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
class TestDomainValidation:
    """Whitelist validation for the domain field."""

    def test_valid_domains(self):
        """Every whitelisted domain round-trips through the validator."""
        for domain in VALID_DOMAINS:
            assert validate_domain(domain) == domain

    def test_case_insensitive(self):
        """Input casing is normalised to lowercase."""
        assert validate_domain("GENERAL") == "general"
        assert validate_domain("General") == "general"
        assert validate_domain("embodied_AI") == "embodied_ai"

    def test_whitespace_trimmed(self):
        """Surrounding whitespace (including tabs/newlines) is stripped."""
        assert validate_domain(" general ") == "general"
        assert validate_domain("\ngeneral\t") == "general"

    def test_sql_injection_domain(self):
        """CRITICAL: classic SQL-injection payloads are rejected outright."""
        attack_payloads = (
            "general'; DROP TABLE corrections--",
            "general' OR '1'='1",
            "'; DELETE FROM corrections WHERE '1'='1",
            "general\"; DROP TABLE--",
            "1' UNION SELECT * FROM corrections--",
        )

        for payload in attack_payloads:
            with pytest.raises(ValidationError, match="Invalid domain"):
                validate_domain(payload)

    def test_empty_domain(self):
        """Empty or whitespace-only domains are refused."""
        for blank in ("", " "):
            with pytest.raises(ValidationError, match="cannot be empty"):
                validate_domain(blank)
|
||||
|
||||
|
||||
class TestSourceValidation:
    """Whitelist validation for the source field."""

    def test_valid_sources(self):
        """Every whitelisted source round-trips through the validator."""
        for source in VALID_SOURCES:
            assert validate_source(source) == source

    def test_invalid_source(self):
        """Unknown or malicious sources are rejected."""
        for bad_source in ("hacked", "'; DROP TABLE--"):
            with pytest.raises(ValidationError, match="Invalid source"):
                validate_source(bad_source)
|
||||
|
||||
|
||||
class TestTextSanitization:
    """Sanitisation of free-text fields: length, null bytes, control chars."""

    def test_valid_text(self):
        """Ordinary text passes through untouched."""
        assert sanitize_text_field("Hello world!", 100, "test") == "Hello world!"

    def test_length_limit(self):
        """Input beyond the declared maximum length is rejected."""
        with pytest.raises(ValidationError, match="too long"):
            sanitize_text_field("a" * 1000, 100, "test")

    def test_null_byte_rejection(self):
        """CRITICAL: embedded null bytes are refused (they can break SQLite)."""
        with pytest.raises(ValidationError, match="null bytes"):
            sanitize_text_field("hello\x00world", 100, "test")

    def test_control_characters(self):
        """Non-whitespace control characters are stripped from the value."""
        cleaned = sanitize_text_field("hello\x01\x02world\x1f", 100, "test")
        assert cleaned == "helloworld"

    def test_whitespace_preserved(self):
        """Tabs and newlines survive sanitisation."""
        cleaned = sanitize_text_field("hello\tworld\ntest\r\nline", 100, "test")
        assert "\t" in cleaned
        assert "\n" in cleaned

    def test_empty_after_sanitization(self):
        """Input that sanitises down to nothing is rejected."""
        with pytest.raises(ValidationError, match="empty after sanitization"):
            sanitize_text_field(" ", 100, "test")
|
||||
|
||||
|
||||
class TestCorrectionInputsValidation:
    """End-to-end validation of a full correction record."""

    def test_valid_inputs(self):
        """A well-formed record comes back as an ordered sequence of fields."""
        validated = validate_correction_inputs(
            from_text="teh",
            to_text="the",
            domain="general",
            source="manual",
            notes="Typo fix",
            added_by="test_user"
        )

        # Field order: from_text, to_text, domain, source, notes, added_by.
        expected = ["teh", "the", "general", "manual", "Typo fix", "test_user"]
        for index, value in enumerate(expected):
            assert validated[index] == value

    def test_invalid_domain_in_full_validation(self):
        """A malicious domain is caught even via the combined validator."""
        with pytest.raises(ValidationError, match="Invalid domain"):
            validate_correction_inputs(
                from_text="test",
                to_text="test",
                domain="hacked'; DROP--",
                source="manual"
            )

    def test_text_too_long(self):
        """from_text longer than MAX_FROM_TEXT_LENGTH is rejected."""
        oversized = "a" * (MAX_FROM_TEXT_LENGTH + 1)

        with pytest.raises(ValidationError, match="too long"):
            validate_correction_inputs(
                from_text=oversized,
                to_text="test",
                domain="general",
                source="manual"
            )

    def test_optional_fields_none(self):
        """notes and added_by may be omitted by passing None."""
        validated = validate_correction_inputs(
            from_text="test",
            to_text="test",
            domain="general",
            source="manual",
            notes=None,
            added_by=None
        )

        assert validated[4] is None  # notes
        assert validated[5] is None  # added_by
|
||||
|
||||
|
||||
class TestConfidenceValidation:
    """Range and type checks for confidence scores."""

    def test_valid_confidence(self):
        """Values across the inclusive [0.0, 1.0] range are accepted."""
        for score in (0.0, 0.5, 1.0):
            assert validate_confidence(score) == score

    def test_confidence_out_of_range(self):
        """Values outside [0.0, 1.0] are rejected."""
        for bad_score in (-0.1, 1.1, 100.0):
            with pytest.raises(ValidationError, match="between 0.0 and 1.0"):
                validate_confidence(bad_score)

    def test_confidence_type_check(self):
        """Non-numeric input is rejected with a type error message."""
        with pytest.raises(ValidationError, match="must be a number"):
            validate_confidence("high")  # type: ignore
|
||||
|
||||
|
||||
class TestSQLIdentifierValidation:
    """Safety checks for strings used as SQL identifiers."""

    def test_safe_identifiers(self):
        """Conventional snake_case / CamelCase identifiers are accepted."""
        for identifier in ("table_name", "_private", "Column123"):
            assert is_safe_sql_identifier(identifier)

    def test_unsafe_identifiers(self):
        """Identifiers with SQL metacharacters or bad shapes are refused."""
        rejected = (
            "table-name",   # hyphen
            "123table",     # leading digit
            "table name",   # embedded space
            "table; DROP",  # statement separator
            "table' OR",    # quote
        )

        for identifier in rejected:
            assert not is_safe_sql_identifier(identifier)

    def test_empty_identifier(self):
        """The empty string is never a valid identifier."""
        assert not is_safe_sql_identifier("")

    def test_too_long_identifier(self):
        """Identifiers beyond 64 characters are refused."""
        assert not is_safe_sql_identifier("a" * 65)
|
||||
|
||||
|
||||
class TestSecurityScenarios:
    """Realistic attack scenarios combining several validators."""

    def test_sql_injection_via_from_text(self):
        """Injection text in from_text is preserved verbatim, not rejected.

        Free-text fields accept arbitrary content; parameterized queries
        downstream are what neutralise the payload.
        """
        payload = "test'; DROP TABLE corrections--"

        validated = validate_correction_inputs(
            from_text=payload,
            to_text="safe",
            domain="general",
            source="manual"
        )

        assert validated[0] == payload  # Text preserved as-is

    def test_dos_via_long_input(self):
        """Length limits stop denial-of-service via oversized payloads."""
        with pytest.raises(ValidationError, match="too long"):
            validate_correction_inputs(
                from_text="a" * 10000,
                to_text="test",
                domain="general",
                source="manual"
            )

    def test_domain_bypass_attempts(self):
        """Encoding tricks cannot smuggle an invalid domain past the validator."""
        bypass_attempts = (
            "general\x00hacked",   # null byte injection
            "general\nmalicious",  # newline injection
            "general -- comment",  # SQL comment
            "general' UNION",      # SQL union
        )

        for attempt in bypass_attempts:
            with pytest.raises(ValidationError):
                validate_domain(attempt)
|
||||
|
||||
|
||||
# Direct invocation: python test_domain_validator.py (equivalent to pytest -v)
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
|
||||
634
transcript-fixer/scripts/tests/test_error_recovery.py
Normal file
634
transcript-fixer/scripts/tests/test_error_recovery.py
Normal file
@@ -0,0 +1,634 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Error Recovery Testing Module
|
||||
|
||||
CRITICAL FIX (P1-10): Comprehensive error recovery testing
|
||||
|
||||
This module tests the system's ability to recover from various failure scenarios:
|
||||
- Database failures and transaction rollbacks
|
||||
- Network failures and retries
|
||||
- File system errors
|
||||
- Concurrent access conflicts
|
||||
- Resource exhaustion
|
||||
- Timeout handling
|
||||
- Data corruption
|
||||
|
||||
Author: Chief Engineer (ISTJ, 20 years experience)
|
||||
Date: 2025-10-29
|
||||
Priority: P1 - High
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import pytest
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add parent directory to path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from core.connection_pool import ConnectionPool, PoolExhaustedError
|
||||
from core.correction_repository import CorrectionRepository, DatabaseError
|
||||
from utils.retry_logic import retry_sync, retry_async, RetryConfig, is_transient_error
|
||||
from utils.concurrency_manager import (
|
||||
ConcurrencyManager,
|
||||
ConcurrencyConfig,
|
||||
BackpressureError,
|
||||
CircuitBreakerOpenError
|
||||
)
|
||||
from utils.rate_limiter import RateLimiter, RateLimitConfig, RateLimitExceeded
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ==================== Test Fixtures ====================
|
||||
|
||||
@pytest.fixture
def temp_db_path():
    """Yield a database path inside a throwaway temporary directory."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield Path(tmp_dir) / "test.db"


@pytest.fixture
def connection_pool(temp_db_path):
    """Yield a small ConnectionPool and guarantee shutdown afterwards."""
    pool = ConnectionPool(temp_db_path, max_connections=3, pool_timeout=2.0)
    yield pool
    pool.close_all()


@pytest.fixture
def correction_repository(temp_db_path):
    """Yield a CorrectionRepository backed by the temporary database."""
    yield CorrectionRepository(temp_db_path, max_connections=3)
    # No explicit teardown: temp_db_path removes the underlying files.


@pytest.fixture
def concurrency_manager():
    """Build a ConcurrencyManager with tight limits suitable for tests."""
    return ConcurrencyManager(
        ConcurrencyConfig(
            max_concurrent=3,
            max_queue_size=5,
            enable_circuit_breaker=True,
            circuit_failure_threshold=3
        )
    )
|
||||
|
||||
|
||||
# ==================== Database Error Recovery Tests ====================
|
||||
|
||||
class TestDatabaseErrorRecovery:
    """Recovery behaviour for database-level failures.

    Fix: the connection-release loop previously used a bare ``except:``,
    which also swallows KeyboardInterrupt/SystemExit; narrowed to
    ``except Exception``.
    """

    def test_transaction_rollback_on_error(self, correction_repository):
        """
        A failed insert must not leave partial data behind.

        Scenario: insert a correction with an out-of-range confidence.
        Expected: error is raised, row count is unchanged afterwards.
        """
        # Add a correction successfully to establish a baseline.
        correction_repository.add_correction(
            from_text="test1",
            to_text="corrected1",
            domain="general",
            source="manual",
            confidence=0.9
        )

        baseline = len(correction_repository.get_all_corrections(domain="general"))
        assert baseline >= 1

        # Invalid confidence (must be 0.0-1.0) should be rejected.
        from utils.domain_validator import ValidationError
        with pytest.raises((ValidationError, DatabaseError)):
            correction_repository.add_correction(
                from_text="test_invalid",
                to_text="corrected",
                domain="general",
                source="manual",
                confidence=1.5  # Invalid: must be 0.0-1.0
            )

        # The failed operation must not have added anything.
        after = len(correction_repository.get_all_corrections(domain="general"))
        assert after == baseline

    def test_connection_pool_recovery_from_exhaustion(self, connection_pool):
        """
        Once held connections are released, the pool serves requests again.

        Scenario: exhaust all connections, then release them.
        Expected: pool becomes available again.
        """
        held = []

        # Acquire every connection, keeping the context managers alive.
        for _ in range(3):
            ctx = connection_pool.get_connection()
            held.append((ctx, ctx.__enter__()))

        # A fourth request must fail within pool_timeout=2.0.
        with pytest.raises((PoolExhaustedError, TimeoutError)):
            with connection_pool.get_connection():
                pass

        # Release everything; cleanup is best-effort but must not hide
        # interrupts, hence except Exception rather than a bare except.
        for ctx, _conn in held:
            try:
                ctx.__exit__(None, None, None)
            except Exception:
                pass  # Ignore errors during cleanup

        # The pool should hand out a connection again.
        with connection_pool.get_connection() as conn:
            assert conn is not None

    def test_database_recovery_from_corruption(self, temp_db_path):
        """
        A corrupted database file is detected and surfaces as an error.

        Scenario: write garbage where the SQLite file should be.
        Expected: repository construction/use fails gracefully.
        """
        with open(temp_db_path, 'wb') as f:
            f.write(b'This is not a valid SQLite database')

        with pytest.raises((sqlite3.DatabaseError, DatabaseError, FileNotFoundError)):
            repo = CorrectionRepository(temp_db_path)
            repo.get_all_corrections()

    def test_concurrent_write_conflict_recovery(self, temp_db_path):
        """
        Concurrent writers to the same key leave exactly one row (UPSERT).

        Each thread builds its own CorrectionRepository because SQLite
        connections must not be shared across threads.
        """
        results = []
        errors = []

        def write_correction(thread_id, db_path):
            try:
                # Per-thread repository: SQLite thread-safety requirement.
                from core.correction_repository import CorrectionRepository
                thread_repo = CorrectionRepository(db_path, max_connections=1)

                thread_repo.add_correction(
                    from_text="concurrent_test",
                    to_text=f"corrected_{thread_id}",
                    domain="general",
                    source="manual"
                )
                results.append(thread_id)
            except Exception as e:
                errors.append((thread_id, str(e)))

        threads = [
            threading.Thread(target=write_correction, args=(i, temp_db_path))
            for i in range(5)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Every thread either succeeded or recorded an error.
        assert len(results) + len(errors) == 5

        # The database must still be consistent and queryable.
        verify_repo = CorrectionRepository(temp_db_path)
        corrections = verify_repo.get_all_corrections()
        matching = [c for c in corrections if c.from_text == "concurrent_test"]

        assert len(matching) >= 1
        # UNIQUE constraint + UPSERT means exactly one row survives.
        assert len(matching) == 1
|
||||
|
||||
|
||||
# ==================== Network Error Recovery Tests ====================
|
||||
|
||||
class TestNetworkErrorRecovery:
    """Retry behaviour for transient vs permanent network failures.

    Fix: ``== True`` / ``== False`` comparisons (PEP 8 E712) replaced with
    direct truth assertions.
    """

    @pytest.mark.asyncio
    async def test_retry_on_transient_network_error(self):
        """
        A timeout on the first two attempts is retried and then succeeds.

        Scenario: API call fails with timeout, then succeeds on retry.
        Expected: operation succeeds after retry; three attempts total.
        """
        attempt_count = [0]

        @retry_async(RetryConfig(max_attempts=3, base_delay=0.1))
        async def flaky_network_call():
            attempt_count[0] += 1
            if attempt_count[0] < 3:
                import httpx
                raise httpx.ConnectTimeout("Connection timeout")
            return "success"

        result = await flaky_network_call()
        assert result == "success"
        assert attempt_count[0] == 3

    @pytest.mark.asyncio
    async def test_no_retry_on_permanent_error(self):
        """
        A non-transient error fails immediately, without retries.

        Scenario: API call fails with an authentication-style error.
        Expected: error raised immediately; exactly one attempt.
        """
        attempt_count = [0]

        @retry_async(RetryConfig(max_attempts=3, base_delay=0.1))
        async def auth_error_call():
            attempt_count[0] += 1
            raise ValueError("Invalid credentials")  # Permanent error

        with pytest.raises(ValueError):
            await auth_error_call()

        # Exactly one attempt: the retry logic classified it as permanent.
        assert attempt_count[0] == 1

    def test_transient_error_classification(self):
        """
        is_transient_error() distinguishes retryable from permanent errors.

        Scenario: various exception types.
        Expected: network timeouts/failures are transient; programming
        errors are permanent.
        """
        import httpx

        # Transient (network) errors should be retried.
        assert is_transient_error(httpx.ConnectTimeout("timeout"))
        assert is_transient_error(httpx.ReadTimeout("timeout"))
        assert is_transient_error(httpx.ConnectError("connection failed"))

        # Programming errors are permanent and must not be retried.
        assert not is_transient_error(ValueError("invalid input"))
        assert not is_transient_error(KeyError("not found"))
|
||||
|
||||
|
||||
# ==================== Concurrency Error Recovery Tests ====================
|
||||
|
||||
class TestConcurrencyErrorRecovery:
    """Circuit-breaker and backpressure recovery paths."""

    @pytest.mark.asyncio
    async def test_circuit_breaker_opens_after_failures(self, concurrency_manager):
        """Three consecutive failures (the threshold) open the circuit."""
        for _ in range(3):
            try:
                async with concurrency_manager.acquire():
                    raise Exception("Simulated failure")
            except Exception:
                pass

        # With the circuit OPEN, the next acquisition is refused outright.
        with pytest.raises(CircuitBreakerOpenError):
            async with concurrency_manager.acquire():
                pass

    @pytest.mark.asyncio
    async def test_circuit_breaker_recovery(self, concurrency_manager):
        """After the recovery timeout, successes close the circuit again.

        Expected transitions: OPEN -> HALF_OPEN -> CLOSED.
        """
        # Shorten the recovery window so the test stays fast.
        concurrency_manager.config.circuit_recovery_timeout = 0.5

        # Trip the breaker with threshold failures.
        for _ in range(3):
            try:
                async with concurrency_manager.acquire():
                    raise Exception("Failure")
            except Exception:
                pass

        assert concurrency_manager.get_metrics().circuit_state.value == "open"

        # Let the recovery timeout elapse.
        await asyncio.sleep(0.6)

        # Two consecutive successes: first probes HALF_OPEN, second closes.
        async with concurrency_manager.acquire():
            pass
        async with concurrency_manager.acquire():
            pass

        final_state = concurrency_manager.get_metrics().circuit_state.value
        assert final_state in ("closed", "half_open")

    @pytest.mark.asyncio
    async def test_backpressure_handling(self):
        """Requests beyond max_queue_size are shed instead of piling up."""
        # Deliberately tiny limits so the queue overflows quickly.
        manager = ConcurrencyManager(
            ConcurrencyConfig(
                max_concurrent=1,
                max_queue_size=2,
                enable_backpressure=True
            )
        )

        async def slow_task():
            async with manager.acquire():
                await asyncio.sleep(0.5)

        tasks = []
        rejected_count = 0

        # Launch more work than the queue can hold.
        for _ in range(6):
            try:
                tasks.append(asyncio.create_task(slow_task()))
                await asyncio.sleep(0.01)  # stagger the starts slightly
            except BackpressureError:
                rejected_count += 1

        # Give the tasks a moment, then cancel whatever is still running.
        await asyncio.sleep(0.1)
        for task in tasks:
            if not task.done():
                task.cancel()

        # Drain tasks, tolerating cancellations and backpressure errors.
        await asyncio.gather(*tasks, return_exceptions=True)

        metrics = manager.get_metrics()

        # Shedding may surface either as a raised BackpressureError or as a
        # rejected-request count in the metrics.
        assert rejected_count > 0 or metrics.rejected_requests > 0
|
||||
|
||||
|
||||
# ==================== Resource Error Recovery Tests ====================
|
||||
|
||||
class TestResourceErrorRecovery:
    """Test resource error recovery mechanisms.

    Covers rate-limiter window resets, acquire-timeouts on the concurrency
    manager, and file-lock timeouts — each scenario verifies the system
    recovers once the transient resource condition clears.
    """

    def test_rate_limiter_recovery_after_limit_reached(self):
        """
        Test that rate limiter allows requests after window resets.

        Scenario: Exhaust rate limit, wait for window reset.
        Expected: New requests are allowed after reset.
        """
        config = RateLimitConfig(
            max_requests=3,
            window_seconds=0.5,  # Short window for testing
        )
        limiter = RateLimiter(config)

        # Exhaust limit.
        # FIX: assert truthiness directly instead of `== True` / `== False`
        # (PEP 8 / flake8 E712) — behavior is unchanged.
        for i in range(3):
            assert limiter.acquire(blocking=False)

        # Should be exhausted
        assert not limiter.acquire(blocking=False)

        # Wait for window reset
        time.sleep(0.6)

        # Should be available again
        assert limiter.acquire(blocking=False)

    @pytest.mark.asyncio
    async def test_timeout_recovery(self, concurrency_manager):
        """
        Test that timeouts are handled gracefully.

        Scenario: Operation exceeds timeout.
        Expected: Operation is cancelled, resources released.
        """
        with pytest.raises(asyncio.TimeoutError):
            async with concurrency_manager.acquire(timeout=0.1):
                await asyncio.sleep(1.0)  # Exceeds timeout

        # Verify metrics were updated
        metrics = concurrency_manager.get_metrics()
        assert metrics.timeout_requests > 0

    def test_file_lock_recovery_after_timeout(self, temp_db_path):
        """
        Test recovery from file lock timeouts.

        Scenario: Lock held too long, timeout occurs.
        Expected: Lock is released, subsequent operations succeed.
        """
        from filelock import FileLock, Timeout as FileLockTimeout

        lock_path = temp_db_path.parent / "test.lock"
        lock = FileLock(str(lock_path), timeout=0.5)

        # Acquire lock
        with lock.acquire():
            # Try to acquire again (should timeout).
            # NOTE(review): py-filelock locks are re-entrant per *object*,
            # so a second FileLock object on the same path is what blocks
            # here — confirm against the filelock documentation.
            lock2 = FileLock(str(lock_path), timeout=0.2)
            with pytest.raises(FileLockTimeout):
                with lock2.acquire():
                    pass

        # Lock should be released, can acquire now
        with lock.acquire():
            pass  # Success
|
||||
|
||||
|
||||
# ==================== Data Corruption Recovery Tests ====================
|
||||
|
||||
class TestDataCorruptionRecovery:
    """Test data corruption detection and recovery"""

    def test_invalid_data_detection(self, correction_repository):
        """
        Test that invalid data is detected and rejected.

        Scenario: Attempt to insert invalid data.
        Expected: Validation error, database remains consistent.
        """
        # A confidence above 1.0 must be rejected by the repository.
        with pytest.raises(DatabaseError):
            correction_repository.add_correction(
                from_text="test",
                to_text="corrected",
                domain="general",
                source="manual",
                confidence=1.5  # Invalid (must be 0.0-1.0)
            )

        # Every stored correction must still carry a valid confidence value.
        for record in correction_repository.get_all_corrections():
            assert 0.0 <= record.confidence <= 1.0

    def test_encoding_error_recovery(self):
        """
        Test recovery from encoding errors.

        Scenario: Process text with invalid encoding.
        Expected: Error is handled, processing continues.
        """
        from core.change_extractor import ChangeExtractor, InputValidationError

        extractor = ChangeExtractor()

        # Bytes that are not valid UTF-8, decoded with replacement characters.
        garbled = b'\x80\x81\x82'.decode('utf-8', errors='replace')

        try:
            # Either the extractor copes with the input, or it raises a
            # specific validation error — both outcomes are acceptable.
            extractor.extract_changes(garbled, "corrected")
        except InputValidationError as exc:
            message = str(exc)
            assert "UTF-8" in message or "encoding" in message.lower()
|
||||
|
||||
|
||||
# ==================== Integration Error Recovery Tests ====================
|
||||
|
||||
class TestIntegrationErrorRecovery:
    """Test end-to-end error recovery scenarios"""

    def test_full_system_recovery_from_multiple_failures(
        self, correction_repository, concurrency_manager
    ):
        """
        Test that system recovers from multiple simultaneous failures.

        Scenario: Database error + rate limit + concurrency limit.
        Expected: System degrades gracefully, recovers when possible.
        """
        # Record initial state
        initial_corrections = len(correction_repository.get_all_corrections())

        # Simulate various failures
        # NOTE(review): `failures` is populated but never asserted on; it only
        # documents which failure paths were exercised during the run.
        failures = []

        # 1. Try to add duplicate correction (database error)
        correction_repository.add_correction(
            from_text="multi_fail_test",
            to_text="original",
            domain="general",
            source="manual"
        )

        try:
            correction_repository.add_correction(
                from_text="multi_fail_test",  # Duplicate
                to_text="duplicate",
                domain="general",
                source="manual"
            )
        except DatabaseError:
            failures.append("database")

        # 2. Simulate concurrency failure
        async def test_concurrency():
            try:
                # Cause circuit breaker to open
                # NOTE(review): assumes the fixture's failure threshold is 3.
                for i in range(3):
                    try:
                        async with concurrency_manager.acquire():
                            raise Exception("Failure")
                    except Exception:
                        pass

                # Circuit should be open
                with pytest.raises(CircuitBreakerOpenError):
                    async with concurrency_manager.acquire():
                        pass
                failures.append("concurrency")
            except Exception:
                pass

        # Run the async portion from this synchronous test body.
        asyncio.run(test_concurrency())

        # Verify system is still operational
        corrections = correction_repository.get_all_corrections()
        assert len(corrections) == initial_corrections + 1

        # Verify metrics were recorded
        metrics = concurrency_manager.get_metrics()
        assert metrics.failed_requests > 0

    @pytest.mark.asyncio
    async def test_cascading_failure_prevention(self):
        """
        Test that failures don't cascade through the system.

        Scenario: One component fails, others continue working.
        Expected: Failure is isolated, system remains operational.
        """
        # This test verifies isolation between components: two independent
        # managers share a config but must not share circuit state.
        config = ConcurrencyConfig(
            max_concurrent=2,
            enable_circuit_breaker=True,
            circuit_failure_threshold=3
        )
        manager1 = ConcurrencyManager(config)
        manager2 = ConcurrencyManager(config)

        # Cause failures in manager1
        for i in range(3):
            try:
                async with manager1.acquire():
                    raise Exception("Failure")
            except Exception:
                pass

        # manager1 circuit should be open
        metrics1 = manager1.get_metrics()
        assert metrics1.circuit_state.value == "open"

        # manager2 should still work
        async with manager2.acquire():
            pass  # Success

        metrics2 = manager2.get_metrics()
        assert metrics2.circuit_state.value == "closed"
|
||||
|
||||
|
||||
# ==================== Test Runner ====================
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (python <file>) instead of
    # through the pytest CLI; -s disables output capture for debugging.
    pytest.main([__file__, "-v", "-s"])
|
||||
464
transcript-fixer/scripts/tests/test_learning_engine.py
Normal file
464
transcript-fixer/scripts/tests/test_learning_engine.py
Normal file
@@ -0,0 +1,464 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test suite for LearningEngine thread-safety.
|
||||
|
||||
CRITICAL FIX (P1-1): Tests for race condition prevention
|
||||
- Concurrent writes to pending suggestions
|
||||
- Concurrent writes to rejected patterns
|
||||
- Concurrent writes to auto-approved patterns
|
||||
- Lock acquisition and release
|
||||
- Deadlock prevention
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from dataclasses import asdict
|
||||
|
||||
import pytest
|
||||
|
||||
# Import classes - note: run tests from scripts/ directory
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# Import only what we need to avoid circular dependencies
|
||||
from dataclasses import dataclass, asdict as dataclass_asdict
|
||||
|
||||
# Manually define Suggestion to avoid circular import
|
||||
@dataclass
class Suggestion:
    """Represents a learned correction suggestion.

    Local re-declaration mirroring the project's Suggestion type so these
    tests avoid a circular import; field layout must stay in sync with
    the original definition.
    """
    from_text: str    # original (incorrect) text as seen in transcripts
    to_text: str      # corrected replacement text
    frequency: int    # number of times this correction has been observed
    confidence: float # confidence score; tests use the 0.0-1.0 range
    examples: List    # occurrence records; element shape not fixed here — see callers
    first_seen: str   # ISO-style date of first observation
    last_seen: str    # ISO-style date of most recent observation
    status: str       # lifecycle state, e.g. "pending"
|
||||
|
||||
# Import LearningEngine last
|
||||
# We'll mock the correction_service dependency to avoid circular imports
|
||||
import core.learning_engine as le_module
|
||||
LearningEngine = le_module.LearningEngine
|
||||
|
||||
|
||||
class TestLearningEngineThreadSafety:
    """Test thread-safety of LearningEngine file operations.

    Each test hammers the engine's file-backed stores from multiple threads
    and asserts no writes were lost — i.e. the file locks added in P1-1
    actually serialize access.
    """

    @pytest.fixture
    def temp_dirs(self):
        """Create temporary directories for testing"""
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            history_dir = temp_path / "history"
            learned_dir = temp_path / "learned"
            history_dir.mkdir()
            learned_dir.mkdir()
            yield history_dir, learned_dir

    @pytest.fixture
    def engine(self, temp_dirs):
        """Create LearningEngine instance"""
        history_dir, learned_dir = temp_dirs
        return LearningEngine(history_dir, learned_dir)

    def test_concurrent_save_pending_no_data_loss(self, engine):
        """
        Test that concurrent writes to pending suggestions don't lose data.

        CRITICAL: This is the main race condition we're preventing.
        Without locks, concurrent appends would overwrite each other.
        """
        num_threads = 10
        suggestions_per_thread = 5

        def save_suggestions(thread_id: int):
            """Save suggestions from a single thread"""
            suggestions = []
            for i in range(suggestions_per_thread):
                suggestions.append(Suggestion(
                    from_text=f"thread{thread_id}_from{i}",
                    to_text=f"thread{thread_id}_to{i}",
                    frequency=1,
                    confidence=0.9,
                    examples=[],
                    first_seen="2025-01-01",
                    last_seen="2025-01-01",
                    status="pending"
                ))
            engine._save_pending_suggestions(suggestions)

        # Launch concurrent threads
        threads = []
        for thread_id in range(num_threads):
            thread = threading.Thread(target=save_suggestions, args=(thread_id,))
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        # Verify: ALL suggestions should be saved
        # NOTE(review): assumes _save_pending_suggestions appends to (rather
        # than replaces) the pending file — confirm against LearningEngine.
        pending = engine._load_pending_suggestions()
        expected_count = num_threads * suggestions_per_thread

        assert len(pending) == expected_count, (
            f"Data loss detected! Expected {expected_count} suggestions, "
            f"but found {len(pending)}. Race condition occurred."
        )

        # Verify uniqueness (no duplicates from overwrites)
        from_texts = [s["from_text"] for s in pending]
        assert len(from_texts) == len(set(from_texts)), "Duplicate suggestions found"

    def test_concurrent_approve_suggestions(self, engine):
        """Test that concurrent approvals don't cause race conditions"""
        # Pre-populate with suggestions
        initial_suggestions = []
        for i in range(20):
            initial_suggestions.append(Suggestion(
                from_text=f"from{i}",
                to_text=f"to{i}",
                frequency=1,
                confidence=0.9,
                examples=[],
                first_seen="2025-01-01",
                last_seen="2025-01-01",
                status="pending"
            ))
        engine._save_pending_suggestions(initial_suggestions)

        # Approve half of them concurrently
        def approve_suggestion(from_text: str):
            engine.approve_suggestion(from_text)

        threads = []
        for i in range(10):
            thread = threading.Thread(target=approve_suggestion, args=(f"from{i}",))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Verify: exactly 10 should remain
        pending = engine._load_pending_suggestions()
        assert len(pending) == 10, f"Expected 10 remaining, found {len(pending)}"

        # Verify: the correct ones remain
        remaining_from_texts = {s["from_text"] for s in pending}
        expected_remaining = {f"from{i}" for i in range(10, 20)}
        assert remaining_from_texts == expected_remaining

    def test_concurrent_reject_suggestions(self, engine):
        """Test that concurrent rejections handle both pending and rejected locks"""
        # Pre-populate with suggestions
        initial_suggestions = []
        for i in range(10):
            initial_suggestions.append(Suggestion(
                from_text=f"from{i}",
                to_text=f"to{i}",
                frequency=1,
                confidence=0.9,
                examples=[],
                first_seen="2025-01-01",
                last_seen="2025-01-01",
                status="pending"
            ))
        engine._save_pending_suggestions(initial_suggestions)

        # Reject all of them concurrently
        def reject_suggestion(from_text: str, to_text: str):
            engine.reject_suggestion(from_text, to_text)

        threads = []
        for i in range(10):
            thread = threading.Thread(
                target=reject_suggestion,
                args=(f"from{i}", f"to{i}")
            )
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Verify: pending should be empty
        pending = engine._load_pending_suggestions()
        assert len(pending) == 0, f"Expected 0 pending, found {len(pending)}"

        # Verify: rejected should have all 10
        rejected = engine._load_rejected()
        assert len(rejected) == 10, f"Expected 10 rejected, found {len(rejected)}"

        # NOTE(review): assumes _load_rejected returns a set of (from, to)
        # tuples — the equality below relies on that.
        expected_rejected = {(f"from{i}", f"to{i}") for i in range(10)}
        assert rejected == expected_rejected

    def test_concurrent_auto_approve_no_data_loss(self, engine):
        """Test that concurrent auto-approvals don't lose data"""
        num_threads = 5
        patterns_per_thread = 3

        def save_auto_approved(thread_id: int):
            """Save auto-approved patterns from a single thread"""
            patterns = []
            for i in range(patterns_per_thread):
                patterns.append({
                    "from": f"thread{thread_id}_from{i}",
                    "to": f"thread{thread_id}_to{i}",
                    "frequency": 5,
                    "confidence": 0.9,
                    "domain": "general"
                })
            engine._save_auto_approved(patterns)

        # Launch concurrent threads
        threads = []
        for thread_id in range(num_threads):
            thread = threading.Thread(target=save_auto_approved, args=(thread_id,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Verify: ALL patterns should be saved — read the JSON file directly
        # rather than going through the engine, to check what actually landed
        # on disk.
        with open(engine.auto_approved_file, 'r') as f:
            data = json.load(f)
            auto_approved = data.get("auto_approved", [])

        expected_count = num_threads * patterns_per_thread
        assert len(auto_approved) == expected_count, (
            f"Data loss in auto-approved! Expected {expected_count}, "
            f"found {len(auto_approved)}"
        )

    def test_lock_timeout_handling(self, engine):
        """Test that lock timeout is handled gracefully"""
        # Acquire lock and hold it
        lock_acquired = threading.Event()
        lock_released = threading.Event()

        def hold_lock():
            """Hold lock for extended period"""
            with engine._file_lock(engine.pending_lock, "hold lock"):
                lock_acquired.set()
                # Hold lock for 2 seconds (or until the main thread signals)
                lock_released.wait(timeout=2.0)

        # Start thread holding lock
        holder_thread = threading.Thread(target=hold_lock)
        holder_thread.start()

        # Wait for lock to be acquired
        lock_acquired.wait(timeout=1.0)

        # Try to acquire lock with short timeout (should fail)
        original_timeout = engine.lock_timeout
        engine.lock_timeout = 0.5  # 500ms timeout

        try:
            with pytest.raises(RuntimeError, match="File lock timeout"):
                with engine._file_lock(engine.pending_lock, "test timeout"):
                    pass
        finally:
            # Restore original timeout
            engine.lock_timeout = original_timeout
            # Release the held lock
            lock_released.set()
            holder_thread.join()

    def test_no_deadlock_with_multiple_locks(self, engine):
        """Test that acquiring multiple locks doesn't cause deadlock"""
        num_threads = 5
        iterations = 10

        def reject_multiple():
            """Reject multiple suggestions (acquires both pending and rejected locks)"""
            for i in range(iterations):
                # This exercises the lock acquisition order
                engine.reject_suggestion(f"from{i}", f"to{i}")

        # Pre-populate
        for i in range(iterations):
            engine._save_pending_suggestions([Suggestion(
                from_text=f"from{i}",
                to_text=f"to{i}",
                frequency=1,
                confidence=0.9,
                examples=[],
                first_seen="2025-01-01",
                last_seen="2025-01-01",
                status="pending"
            )])

        # Launch concurrent rejections
        threads = []
        for _ in range(num_threads):
            thread = threading.Thread(target=reject_multiple)
            threads.append(thread)
            thread.start()

        # Wait for completion (with timeout to detect deadlock)
        deadline = time.time() + 10.0  # 10 second deadline
        for thread in threads:
            remaining = deadline - time.time()
            if remaining <= 0:
                pytest.fail("Deadlock detected! Threads did not complete in time.")
            thread.join(timeout=remaining)
            if thread.is_alive():
                pytest.fail("Deadlock detected! Thread still alive after timeout.")

        # If we get here, no deadlock occurred
        assert True

    def test_lock_files_created(self, engine):
        """Test that lock files are created in correct location"""
        # Trigger an operation that uses locks
        suggestions = [Suggestion(
            from_text="test",
            to_text="test",
            frequency=1,
            confidence=0.9,
            examples=[],
            first_seen="2025-01-01",
            last_seen="2025-01-01",
            status="pending"
        )]
        engine._save_pending_suggestions(suggestions)

        # Lock files should exist (they're created by filelock)
        # Note: filelock may clean up lock files after release
        # So we just verify the paths are correctly configured
        assert engine.pending_lock.name == ".pending_review.lock"
        assert engine.rejected_lock.name == ".rejected.lock"
        assert engine.auto_approved_lock.name == ".auto_approved.lock"

    def test_directory_creation_under_lock(self, engine):
        """Test that directory creation is safe under lock"""
        # Remove learned directory
        import shutil
        if engine.learned_dir.exists():
            shutil.rmtree(engine.learned_dir)

        # Recreate it concurrently (parent.mkdir in save methods)
        def save_concurrent():
            suggestions = [Suggestion(
                from_text="test",
                to_text="test",
                frequency=1,
                confidence=0.9,
                examples=[],
                first_seen="2025-01-01",
                last_seen="2025-01-01",
                status="pending"
            )]
            engine._save_pending_suggestions(suggestions)

        threads = []
        for _ in range(5):
            thread = threading.Thread(target=save_concurrent)
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Directory should exist and contain data
        assert engine.learned_dir.exists()
        assert engine.pending_file.exists()
|
||||
|
||||
|
||||
class TestLearningEngineCorrectness:
    """Test that file locking doesn't break functionality"""

    @pytest.fixture
    def temp_dirs(self):
        """Create temporary directories for testing"""
        with tempfile.TemporaryDirectory() as temp_dir:
            base = Path(temp_dir)
            dirs = (base / "history", base / "learned")
            for directory in dirs:
                directory.mkdir()
            yield dirs

    @pytest.fixture
    def engine(self, temp_dirs):
        """Create LearningEngine instance"""
        return LearningEngine(*temp_dirs)

    def test_save_and_load_pending(self, engine):
        """Test basic save and load functionality"""
        example = {"file": "test.md", "line": 1, "context": "test", "timestamp": "2025-01-01"}
        engine._save_pending_suggestions([Suggestion(
            from_text="hello",
            to_text="你好",
            frequency=5,
            confidence=0.95,
            examples=[example],
            first_seen="2025-01-01",
            last_seen="2025-01-02",
            status="pending",
        )])

        stored = engine._load_pending_suggestions()

        # Round-trip must preserve every field, including non-ASCII text.
        assert len(stored) == 1
        record = stored[0]
        assert record["from_text"] == "hello"
        assert record["to_text"] == "你好"
        assert record["confidence"] == 0.95

    def test_approve_removes_from_pending(self, engine):
        """Test that approval removes suggestion from pending"""
        engine._save_pending_suggestions([Suggestion(
            from_text="test",
            to_text="测试",
            frequency=3,
            confidence=0.9,
            examples=[],
            first_seen="2025-01-01",
            last_seen="2025-01-01",
            status="pending",
        )])
        assert len(engine._load_pending_suggestions()) == 1

        # Approving must report success and empty the pending store.
        assert engine.approve_suggestion("test") is True
        assert len(engine._load_pending_suggestions()) == 0

    def test_reject_moves_to_rejected(self, engine):
        """Test that rejection moves suggestion to rejected list"""
        engine._save_pending_suggestions([Suggestion(
            from_text="bad",
            to_text="wrong",
            frequency=1,
            confidence=0.8,
            examples=[],
            first_seen="2025-01-01",
            last_seen="2025-01-01",
            status="pending",
        )])
        engine.reject_suggestion("bad", "wrong")

        # Should be removed from pending...
        assert len(engine._load_pending_suggestions()) == 0

        # ...and recorded in the rejected set.
        assert ("bad", "wrong") in engine._load_rejected()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (python <file>) instead of
    # through the pytest CLI; --tb=short keeps tracebacks compact.
    pytest.main([__file__, "-v", "--tb=short"])
|
||||
436
transcript-fixer/scripts/tests/test_path_validator.py
Normal file
436
transcript-fixer/scripts/tests/test_path_validator.py
Normal file
@@ -0,0 +1,436 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Suite for Path Validator
|
||||
|
||||
CRITICAL FIX VERIFICATION: Tests for Critical-5
|
||||
Purpose: Verify path traversal and symlink attack prevention
|
||||
|
||||
Test Coverage:
|
||||
1. Path traversal prevention (../)
|
||||
2. Symlink attack detection
|
||||
3. Directory whitelist enforcement
|
||||
4. File extension validation
|
||||
5. Null byte injection prevention
|
||||
6. Path canonicalization
|
||||
|
||||
Author: Chief Engineer
|
||||
Priority: P0 - Critical
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from utils.path_validator import (
|
||||
PathValidator,
|
||||
PathValidationError,
|
||||
validate_input_path,
|
||||
validate_output_path,
|
||||
ALLOWED_READ_EXTENSIONS,
|
||||
ALLOWED_WRITE_EXTENSIONS,
|
||||
)
|
||||
|
||||
|
||||
class TestPathTraversalPrevention:
    """Test path traversal attack prevention.

    Verifies that PathValidator rejects `../` sequences and absolute paths
    that escape the configured whitelist (Critical-5 fix).
    """

    def test_parent_directory_traversal(self, tmp_path):
        """Test ../ path traversal is blocked"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        # Create a file OUTSIDE the allowed directory. This lives in
        # tmp_path's *parent*, which pytest does not clean up for us.
        outside_dir = tmp_path.parent / "outside"
        outside_dir.mkdir(exist_ok=True)
        outside_file = outside_dir / "secret.md"
        outside_file.write_text("secret data")

        try:
            # Try to access it via ../
            malicious_path = str(tmp_path / ".." / "outside" / "secret.md")

            with pytest.raises(PathValidationError, match="Dangerous pattern"):
                validator.validate_input_path(malicious_path)
        finally:
            # FIX: previously the cleanup ran after the assertion, so a
            # failing test leaked files outside tmp_path into subsequent
            # runs; try/finally guarantees removal either way.
            outside_file.unlink()
            outside_dir.rmdir()

    def test_absolute_path_outside_whitelist(self, tmp_path):
        """Test absolute paths outside whitelist are blocked"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        # /etc/passwd is the canonical sensitive target for traversal tests.
        with pytest.raises(PathValidationError, match="not under allowed directories"):
            validator.validate_input_path("/etc/passwd")

    def test_multiple_parent_traversals(self, tmp_path):
        """Test ../../ is blocked"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        with pytest.raises(PathValidationError, match="Dangerous pattern"):
            validator.validate_input_path("../../etc/passwd")
|
||||
|
||||
|
||||
class TestSymlinkAttacks:
    """Test symlink attack prevention"""

    def test_direct_symlink_blocked(self, tmp_path):
        """Test direct symlink is blocked by default"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        # A target file plus a symlink pointing at it.
        target = tmp_path / "real.md"
        target.write_text("data")
        link = tmp_path / "link.md"
        link.symlink_to(target)

        # Symlinks are rejected unless explicitly allowed.
        with pytest.raises(PathValidationError, match="Symlink detected"):
            validator.validate_input_path(str(link))

        # Cleanup
        link.unlink()
        target.unlink()

    def test_symlink_allowed_when_configured(self, tmp_path):
        """Test symlinks can be allowed"""
        validator = PathValidator(
            allowed_base_dirs={tmp_path},
            allow_symlinks=True
        )

        target = tmp_path / "real.md"
        target.write_text("data")
        link = tmp_path / "link.md"
        link.symlink_to(target)

        # With allow_symlinks=True the link validates successfully.
        assert validator.validate_input_path(str(link)).exists()

        # Cleanup
        link.unlink()
        target.unlink()

    def test_symlink_in_parent_directory(self, tmp_path):
        """Test symlink in parent path is blocked"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        # A real directory containing a file, reached through a symlinked
        # directory component.
        actual_dir = tmp_path / "real_dir"
        actual_dir.mkdir()
        linked_dir = tmp_path / "link_dir"
        linked_dir.symlink_to(actual_dir)

        inner_file = actual_dir / "file.md"
        inner_file.write_text("data")

        # Accessing the file via the symlinked directory must be rejected.
        with pytest.raises(PathValidationError, match="Symlink"):
            validator.validate_input_path(str(linked_dir / "file.md"))

        # Cleanup
        inner_file.unlink()
        linked_dir.unlink()
        actual_dir.rmdir()
|
||||
|
||||
|
||||
class TestDirectoryWhitelist:
    """Test directory whitelist enforcement"""

    def test_file_in_allowed_directory(self, tmp_path):
        """Test file in allowed directory is accepted"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        candidate = tmp_path / "test.md"
        candidate.write_text("test data")

        # Validation must return the canonical (resolved) path.
        assert validator.validate_input_path(str(candidate)) == candidate.resolve()

        candidate.unlink()

    def test_file_outside_allowed_directory(self, tmp_path):
        """Test file outside allowed directory is rejected"""
        sandbox = tmp_path / "allowed"
        sandbox.mkdir()

        validator = PathValidator(allowed_base_dirs={sandbox})

        # A file sitting in tmp_path itself, i.e. not inside the whitelist.
        stray = tmp_path / "outside.md"
        stray.write_text("data")

        with pytest.raises(PathValidationError, match="not under allowed directories"):
            validator.validate_input_path(str(stray))

        stray.unlink()

    def test_add_allowed_directory(self, tmp_path):
        """Test dynamically adding allowed directories"""
        validator = PathValidator(allowed_base_dirs={tmp_path / "initial"})

        extra_dir = tmp_path / "new"
        extra_dir.mkdir()

        target = extra_dir / "test.md"
        target.write_text("data")

        # Not whitelisted yet -> rejected.
        with pytest.raises(PathValidationError):
            validator.validate_input_path(str(target))

        # Whitelist the directory at runtime; validation now succeeds.
        validator.add_allowed_directory(extra_dir)
        assert validator.validate_input_path(str(target)).exists()

        target.unlink()
|
||||
|
||||
|
||||
class TestFileExtensionValidation:
    """Test file extension validation"""

    def test_allowed_read_extension(self, tmp_path):
        """Test allowed read extensions are accepted"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        for suffix in ('.md', '.txt', '.html', '.json'):
            candidate = tmp_path / f"test{suffix}"
            candidate.write_text("data")

            # Each whitelisted extension must validate cleanly.
            assert validator.validate_input_path(str(candidate)).exists()

            candidate.unlink()

    def test_disallowed_read_extension(self, tmp_path):
        """Test disallowed extensions are rejected for reading"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        for name in ("script.sh", "executable.exe", "code.py", "binary.bin"):
            candidate = tmp_path / name
            candidate.write_text("data")

            with pytest.raises(PathValidationError, match="not allowed for reading"):
                validator.validate_input_path(str(candidate))

            candidate.unlink()

    def test_allowed_write_extension(self, tmp_path):
        """Test allowed write extensions are accepted"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        for suffix in ('.md', '.html', '.db', '.log'):
            # Output files need not exist yet; their parent directory must.
            target = tmp_path / f"output{suffix}"
            assert validator.validate_output_path(str(target)).parent.exists()

    def test_disallowed_write_extension(self, tmp_path):
        """Test disallowed extensions are rejected for writing"""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        with pytest.raises(PathValidationError, match="not allowed for writing"):
            validator.validate_output_path(str(tmp_path / "output.exe"))
|
||||
|
||||
|
||||
class TestNullByteInjection:
    """Null-byte injection prevention."""

    def test_null_byte_in_path(self, tmp_path):
        """Any path embedding a NUL byte is rejected outright."""
        checker = PathValidator(allowed_base_dirs={tmp_path})

        for candidate in ("file.md\x00.exe", "file\x00.md", "\x00etc/passwd"):
            with pytest.raises(PathValidationError, match="Dangerous pattern"):
                checker.validate_input_path(candidate)
class TestNewlineInjection:
    """Newline injection prevention."""

    def test_newline_in_path(self, tmp_path):
        """Any path embedding CR or LF characters is rejected outright."""
        checker = PathValidator(allowed_base_dirs={tmp_path})

        for candidate in ("file\n.md", "file.md\r\n", "file\r.md"):
            with pytest.raises(PathValidationError, match="Dangerous pattern"):
                checker.validate_input_path(candidate)
class TestOutputPathValidation:
    """Output-path validation and parent-directory handling."""

    def test_output_path_creates_parent(self, tmp_path):
        """create_parent=True materialises a missing parent directory."""
        checker = PathValidator(allowed_base_dirs={tmp_path})
        destination = tmp_path / "subdir" / "output.md"

        validated = checker.validate_output_path(str(destination), create_parent=True)

        assert validated.parent.exists()
        assert validated == destination.resolve()

    def test_output_path_no_create_parent(self, tmp_path):
        """A missing parent with create_parent=False raises an error."""
        checker = PathValidator(allowed_base_dirs={tmp_path})
        destination = tmp_path / "nonexistent" / "output.md"

        with pytest.raises(PathValidationError, match="Parent directory does not exist"):
            checker.validate_output_path(str(destination), create_parent=False)
class TestEdgeCases:
    """Edge cases and corner scenarios."""

    def test_empty_path(self):
        """The empty string is not a valid input path."""
        checker = PathValidator()

        with pytest.raises(PathValidationError):
            checker.validate_input_path("")

    def test_directory_instead_of_file(self, tmp_path):
        """A directory is rejected where a file is expected."""
        checker = PathValidator(allowed_base_dirs={tmp_path})
        folder = tmp_path / "testdir"
        folder.mkdir()

        with pytest.raises(PathValidationError, match="not a file"):
            checker.validate_input_path(str(folder))

        folder.rmdir()

    def test_nonexistent_file(self, tmp_path):
        """A missing file is rejected for reading."""
        checker = PathValidator(allowed_base_dirs={tmp_path})

        with pytest.raises(PathValidationError, match="does not exist"):
            checker.validate_input_path(str(tmp_path / "nonexistent.md"))

    def test_case_insensitive_extension(self, tmp_path):
        """Extension matching ignores letter case."""
        checker = PathValidator(allowed_base_dirs={tmp_path})

        shouty = tmp_path / "TEST.MD"  # upper-case suffix on purpose
        shouty.write_text("data")

        # Must be accepted despite the non-lowercase extension.
        assert checker.validate_input_path(str(shouty)).exists()

        shouty.unlink()
class TestGlobalValidator:
    """Convenience wrappers backed by the process-wide validator."""

    def test_global_validate_input_path(self, tmp_path):
        """validate_input_path() delegates to the shared validator."""
        from utils.path_validator import get_validator

        # NOTE(review): this widens the process-wide whitelist and is never
        # undone, so it can leak into later tests — confirm whether the
        # validator exposes a reset hook.
        get_validator().add_allowed_directory(tmp_path)

        sample = tmp_path / "test.md"
        sample.write_text("data")

        assert validate_input_path(str(sample)).exists()

        sample.unlink()

    def test_global_validate_output_path(self, tmp_path):
        """validate_output_path() delegates to the shared validator."""
        from utils.path_validator import get_validator

        get_validator().add_allowed_directory(tmp_path)

        destination = tmp_path / "output.md"

        assert validate_output_path(str(destination)) == destination.resolve()
class TestSecurityScenarios:
    """Realistic attack scenarios against the validator."""

    def test_zipslip_attack(self, tmp_path):
        """Zipslip-style relative traversal is blocked."""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        # Classic zipslip payload: ../../../etc/passwd
        with pytest.raises(PathValidationError, match="Dangerous pattern"):
            validator.validate_input_path("../../../etc/passwd")

    def test_windows_path_traversal(self, tmp_path):
        """Windows-style backslash traversal is blocked."""
        validator = PathValidator(allowed_base_dirs={tmp_path})

        malicious_paths = [
            "..\\..\\..\\windows\\system32",
            "C:\\..\\..\\etc\\passwd",
        ]

        for path in malicious_paths:
            with pytest.raises(PathValidationError):
                validator.validate_input_path(path)

    def test_home_directory_expansion_safe(self, tmp_path):
        """Tilde expansion resolves safely against the default whitelist.

        This test intentionally touches the real home directory because it
        exercises PathValidator's built-in default whitelist (~/Documents),
        which cannot be redirected into tmp_path.
        """
        home = Path.home()
        test_file = home / "Documents" / "test_path_validator.md"
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.write_text("test")

        try:
            validator = PathValidator()  # Uses default whitelist including ~/Documents

            # '~' must be expanded before validation.
            result = validator.validate_input_path("~/Documents/test_path_validator.md")
            assert result.exists()
        finally:
            # Guaranteed cleanup: previously an assertion failure left the
            # file behind in the user's real home directory, since unlink()
            # only ran after the asserts.
            test_file.unlink(missing_ok=True)
# Run tests with: pytest -v test_path_validator.py
if __name__ == "__main__":
    # Allow executing this file directly without invoking pytest manually;
    # -v for verbose test names, --tb=short for compact tracebacks.
    pytest.main([__file__, "-v", "--tb=short"])
Reference in New Issue
Block a user