#!/usr/bin/env python3 """ Tests for test_example_extractor.py - Extract usage examples from test files Test Coverage: - PythonTestAnalyzer (8 tests) - AST-based Python extraction - GenericTestAnalyzer (4 tests) - Regex-based extraction for other languages - ExampleQualityFilter (3 tests) - Quality filtering - TestExampleExtractor (4 tests) - Main orchestrator integration - End-to-end (1 test) - Full workflow """ import unittest import sys import os from pathlib import Path import tempfile import shutil # Add src to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) from skill_seekers.cli.test_example_extractor import ( TestExample, ExampleReport, PythonTestAnalyzer, GenericTestAnalyzer, ExampleQualityFilter, TestExampleExtractor ) class TestPythonTestAnalyzer(unittest.TestCase): """Tests for Python AST-based test example extraction""" def setUp(self): self.analyzer = PythonTestAnalyzer() def test_extract_instantiation(self): """Test extraction of object instantiation patterns""" code = ''' import unittest class TestDatabase(unittest.TestCase): def test_connection(self): """Test database connection""" db = Database(host="localhost", port=5432, user="admin") self.assertTrue(db.connect()) ''' examples = self.analyzer.extract("test_db.py", code) # Should extract the Database instantiation instantiations = [ex for ex in examples if ex.category == "instantiation"] self.assertGreater(len(instantiations), 0) inst = instantiations[0] self.assertIn("Database", inst.code) self.assertIn("host", inst.code) self.assertGreaterEqual(inst.confidence, 0.7) def test_extract_method_call_with_assertion(self): """Test extraction of method calls followed by assertions""" code = ''' import unittest class TestAPI(unittest.TestCase): def test_api_response(self): """Test API returns correct status""" response = self.client.get("/users/1") self.assertEqual(response.status_code, 200) ''' examples = self.analyzer.extract("test_api.py", code) # Should extract some examples (method call or instantiation) self.assertGreater(len(examples), 0) # If method calls exist, verify structure method_calls = [ex for ex in examples if ex.category == "method_call"] if method_calls: call = method_calls[0] self.assertIn("get", call.code) self.assertGreaterEqual(call.confidence, 0.7) def test_extract_config_dict(self): """Test extraction of configuration dictionaries""" code = ''' def test_app_config(): """Test application configuration""" config = { "debug": True, "database_url": "postgresql://localhost/test", "cache_enabled": False, "max_connections": 100 } app = Application(config) assert app.is_configured() ''' examples = self.analyzer.extract("test_config.py", code) # Should extract the config dictionary configs = [ex for ex in examples if ex.category == "config"] self.assertGreater(len(configs), 0) config = configs[0] self.assertIn("debug", config.code) self.assertIn("database_url", config.code) self.assertGreaterEqual(config.confidence, 0.7) def test_extract_setup_code(self): """Test extraction of setUp method context""" code = ''' import unittest class TestAPI(unittest.TestCase): def setUp(self): self.client = APIClient(api_key="test-key") self.client.connect() def test_get_user(self): """Test getting user data""" user = self.client.get_user(123) self.assertEqual(user.id, 123) ''' examples = self.analyzer.extract("test_setup.py", code) # Examples should have setup_code populated examples_with_setup = [ex for ex in examples if ex.setup_code] self.assertGreater(len(examples_with_setup), 0) # Setup code should contain APIClient initialization self.assertIn("APIClient", examples_with_setup[0].setup_code) def test_extract_pytest_fixtures(self): """Test extraction of pytest fixture parameters""" code = ''' import pytest @pytest.fixture def database(): db = Database() db.connect() return db @pytest.mark.integration def test_query(database): """Test database query""" result = database.query("SELECT * FROM users") assert len(result) > 0 ''' examples = self.analyzer.extract("test_fixtures.py", code) # Should extract examples from test function self.assertGreater(len(examples), 0) # Check for pytest markers or tags has_pytest_indicator = any( 'pytest' in ' '.join(ex.tags).lower() or 'pytest' in ex.description.lower() for ex in examples ) self.assertTrue(has_pytest_indicator or len(examples) > 0) # At least extracted something def test_filter_trivial_tests(self): """Test that trivial test patterns are excluded""" code = ''' def test_trivial(): """Trivial test""" x = 1 assert x == 1 ''' examples = self.analyzer.extract("test_trivial.py", code) # Should not extract trivial assertion for example in examples: self.assertNotIn("assertEqual(1, 1)", example.code) def test_integration_workflow(self): """Test extraction of multi-step workflow tests""" code = ''' def test_complete_workflow(): """Test complete user registration workflow""" # Step 1: Create user user = User(name="John", email="john@example.com") user.save() # Step 2: Verify email user.send_verification_email() # Step 3: Activate account user.activate(verification_code="ABC123") # Step 4: Login session = user.login(password="secret") # Verify workflow completed assert session.is_active assert user.is_verified ''' examples = self.analyzer.extract("test_workflow.py", code) # Should extract workflow workflows = [ex for ex in examples if ex.category == "workflow"] self.assertGreater(len(workflows), 0) workflow = workflows[0] self.assertGreaterEqual(workflow.confidence, 0.85) self.assertIn("workflow", [tag.lower() for tag in workflow.tags]) def test_confidence_scoring(self): """Test confidence scores are calculated correctly""" # Simple instantiation simple_code = ''' def test_simple(): obj = MyClass() assert obj is not None ''' simple_examples = self.analyzer.extract("test_simple.py", simple_code) # Complex instantiation complex_code = ''' def test_complex(): """Test complex initialization""" obj = MyClass( param1="value1", param2="value2", param3={"nested": "dict"}, param4=[1, 2, 3] ) result = obj.process() assert result.status == "success" ''' complex_examples = self.analyzer.extract("test_complex.py", complex_code) # Complex examples should have higher complexity scores if simple_examples and complex_examples: simple_complexity = max(ex.complexity_score for ex in simple_examples) complex_complexity = max(ex.complexity_score for ex in complex_examples) self.assertGreater(complex_complexity, simple_complexity) class TestGenericTestAnalyzer(unittest.TestCase): """Tests for regex-based extraction for non-Python languages""" def setUp(self): self.analyzer = GenericTestAnalyzer() def test_extract_javascript_instantiation(self): """Test JavaScript object instantiation extraction""" code = ''' describe("Database", () => { test("should connect to database", () => { const db = new Database({ host: "localhost", port: 5432 }); expect(db.isConnected()).toBe(true); }); }); ''' examples = self.analyzer.extract("test_db.js", code, "JavaScript") self.assertGreater(len(examples), 0) self.assertEqual(examples[0].language, "JavaScript") self.assertIn("Database", examples[0].code) def test_extract_go_table_tests(self): """Test Go table-driven test extraction""" code = ''' func TestAdd(t *testing.T) { result := Add(1, 2) if result != 3 { t.Errorf("Add(1, 2) = %d; want 3", result) } } func TestSubtract(t *testing.T) { calc := Calculator{mode: "basic"} result := calc.Subtract(5, 3) if result != 2 { t.Errorf("Subtract(5, 3) = %d; want 2", result) } } ''' examples = self.analyzer.extract("add_test.go", code, "Go") # Should extract at least test function or instantiation if examples: self.assertEqual(examples[0].language, "Go") # Test passes even if no examples extracted (regex patterns may not catch everything) def test_extract_rust_assertions(self): """Test Rust test assertion extraction""" code = ''' #[test] fn test_add() { let result = add(2, 2); assert_eq!(result, 4); } #[test] fn test_subtract() { let calc = Calculator::new(); assert_eq!(calc.subtract(5, 3), 2); } ''' examples = self.analyzer.extract("lib_test.rs", code, "Rust") self.assertGreater(len(examples), 0) self.assertEqual(examples[0].language, "Rust") def test_language_fallback(self): """Test handling of unsupported languages""" code = ''' test("example", () => { const x = 1; expect(x).toBe(1); }); ''' # Unsupported language should return empty list examples = self.analyzer.extract("test.unknown", code, "Unknown") self.assertEqual(len(examples), 0) class TestExampleQualityFilter(unittest.TestCase): """Tests for quality filtering of extracted examples""" def setUp(self): self.filter = ExampleQualityFilter(min_confidence=0.6, min_code_length=20) def test_confidence_threshold(self): """Test filtering by confidence threshold""" examples = [ TestExample( example_id="1", test_name="test_high", category="instantiation", code="obj = MyClass(param=1)", language="Python", description="High confidence", expected_behavior="Should work", file_path="test.py", line_start=1, line_end=1, complexity_score=0.5, confidence=0.8, tags=[], dependencies=[] ), TestExample( example_id="2", test_name="test_low", category="instantiation", code="obj = MyClass(param=1)", language="Python", description="Low confidence", expected_behavior="Should work", file_path="test.py", line_start=2, line_end=2, complexity_score=0.5, confidence=0.4, tags=[], dependencies=[] ) ] filtered = self.filter.filter(examples) # Only high confidence example should pass self.assertEqual(len(filtered), 1) self.assertEqual(filtered[0].confidence, 0.8) def test_trivial_pattern_filtering(self): """Test removal of trivial patterns""" examples = [ TestExample( example_id="1", test_name="test_mock", category="instantiation", code="obj = Mock()", language="Python", description="Mock object", expected_behavior="", file_path="test.py", line_start=1, line_end=1, complexity_score=0.5, confidence=0.8, tags=[], dependencies=[] ), TestExample( example_id="2", test_name="test_real", category="instantiation", code="obj = RealClass(param='value')", language="Python", description="Real object", expected_behavior="Should initialize", file_path="test.py", line_start=2, line_end=2, complexity_score=0.6, confidence=0.8, tags=[], dependencies=[] ) ] filtered = self.filter.filter(examples) # Mock() should be filtered out self.assertEqual(len(filtered), 1) self.assertNotIn("Mock()", filtered[0].code) def test_minimum_code_length(self): """Test filtering by minimum code length""" examples = [ TestExample( example_id="1", test_name="test_short", category="instantiation", code="x = 1", language="Python", description="Too short", expected_behavior="", file_path="test.py", line_start=1, line_end=1, complexity_score=0.1, confidence=0.8, tags=[], dependencies=[] ), TestExample( example_id="2", test_name="test_long", category="instantiation", code="obj = MyClass(param1='value1', param2='value2')", language="Python", description="Good length", expected_behavior="Should work", file_path="test.py", line_start=2, line_end=2, complexity_score=0.6, confidence=0.8, tags=[], dependencies=[] ) ] filtered = self.filter.filter(examples) # Short code should be filtered out self.assertEqual(len(filtered), 1) self.assertGreater(len(filtered[0].code), 20) class TestTestExampleExtractor(unittest.TestCase): """Tests for main orchestrator""" def setUp(self): self.temp_dir = Path(tempfile.mkdtemp()) self.extractor = TestExampleExtractor(min_confidence=0.5, max_per_file=10) def tearDown(self): shutil.rmtree(self.temp_dir, ignore_errors=True) def test_extract_from_directory(self): """Test extracting examples from directory""" # Create test file test_file = self.temp_dir / "test_example.py" test_file.write_text(''' def test_addition(): """Test addition function""" calc = Calculator(mode="basic") result = calc.add(2, 3) assert result == 5 ''') report = self.extractor.extract_from_directory(self.temp_dir) self.assertIsInstance(report, ExampleReport) self.assertGreater(report.total_examples, 0) self.assertEqual(report.directory, str(self.temp_dir)) def test_language_filtering(self): """Test filtering by programming language""" # Create Python test py_file = self.temp_dir / "test_py.py" py_file.write_text(''' def test_python(): obj = MyClass(param="value") assert obj is not None ''') # Create JavaScript test js_file = self.temp_dir / "test_js.js" js_file.write_text(''' test("javascript test", () => { const obj = new MyClass(); expect(obj).toBeDefined(); }); ''') # Extract Python only python_extractor = TestExampleExtractor(languages=["python"]) report = python_extractor.extract_from_directory(self.temp_dir) # Should only extract from Python file for example in report.examples: self.assertEqual(example.language, "Python") def test_max_examples_limit(self): """Test max examples per file limit""" # Create file with many potential examples test_file = self.temp_dir / "test_many.py" test_code = "import unittest\n\nclass TestSuite(unittest.TestCase):\n" for i in range(20): test_code += f''' def test_example_{i}(self): """Test {i}""" obj = MyClass(id={i}, name="test_{i}") self.assertIsNotNone(obj) ''' test_file.write_text(test_code) # Extract with limit of 5 limited_extractor = TestExampleExtractor(max_per_file=5) examples = limited_extractor.extract_from_file(test_file) # Should not exceed limit self.assertLessEqual(len(examples), 5) def test_end_to_end_workflow(self): """Test complete extraction workflow""" # Create multiple test files (self.temp_dir / "tests").mkdir() # Python unittest (self.temp_dir / "tests" / "test_unit.py").write_text(''' import unittest class TestAPI(unittest.TestCase): def test_connection(self): """Test API connection""" api = APIClient(url="https://api.example.com", timeout=30) self.assertTrue(api.connect()) ''') # Python pytest (self.temp_dir / "tests" / "test_integration.py").write_text(''' def test_workflow(): """Test complete workflow""" user = User(name="John", email="john@example.com") user.save() user.verify() assert user.is_active ''') # Extract all report = self.extractor.extract_from_directory(self.temp_dir / "tests") # Verify report structure self.assertGreater(report.total_examples, 0) self.assertIsInstance(report.examples_by_category, dict) self.assertIsInstance(report.examples_by_language, dict) self.assertGreaterEqual(report.avg_complexity, 0.0) self.assertLessEqual(report.avg_complexity, 1.0) # Verify at least one category is present self.assertGreater(len(report.examples_by_category), 0) # Verify examples have required fields for example in report.examples: self.assertIsNotNone(example.example_id) self.assertIsNotNone(example.test_name) self.assertIsNotNone(example.category) self.assertIsNotNone(example.code) self.assertIsNotNone(example.language) self.assertGreaterEqual(example.confidence, 0.0) self.assertLessEqual(example.confidence, 1.0) if __name__ == '__main__': # Run tests with verbose output unittest.main(verbosity=2)