From 0d0eda7149a645383735e4c5139bd9295d625f79 Mon Sep 17 00:00:00 2001
From: Joseph Magly <1159087+jmagly@users.noreply.github.com>
Date: Sun, 21 Dec 2025 14:31:38 -0500
Subject: [PATCH] feat(utils): add retry utilities with exponential backoff (#208)

Add retry_with_backoff() and retry_with_backoff_async() for network operations.

Features:
- Configurable max attempts (default: 3)
- Exponential backoff with configurable base delay
- Operation name for meaningful log messages
- Both sync and async versions

Addresses E2.6: Add retry logic for network failures

Co-authored-by: Joseph Magly <1159087+jmagly@users.noreply.github.com>
---
 src/skill_seekers/cli/utils.py | 118 ++++++++++++++++++++++++++++++++-
 tests/test_utilities.py        | 118 ++++++++++++++++++++++++++++++++-
 2 files changed, 234 insertions(+), 2 deletions(-)

diff --git a/src/skill_seekers/cli/utils.py b/src/skill_seekers/cli/utils.py
index 64612c2..dd870e5 100755
--- a/src/skill_seekers/cli/utils.py
+++ b/src/skill_seekers/cli/utils.py
@@ -7,8 +7,14 @@ import os
 import sys
 import subprocess
 import platform
+import time
+import logging
 from pathlib import Path
-from typing import Optional, Tuple, Dict, Union
+from typing import Optional, Tuple, Dict, Union, TypeVar, Callable, Awaitable
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar('T')
 
 
 def open_folder(folder_path: Union[str, Path]) -> bool:
@@ -225,3 +231,113 @@ def read_reference_files(skill_dir: Union[str, Path], max_chars: int = 100000, p
             break
 
     return references
+
+
+def retry_with_backoff(
+    operation: Callable[[], T],
+    max_attempts: int = 3,
+    base_delay: float = 1.0,
+    operation_name: str = "operation"
+) -> T:
+    """Retry an operation with exponential backoff.
+
+    Useful for network operations that may fail due to transient errors.
+    Waits progressively longer between retries (exponential backoff).
+
+    Args:
+        operation: Function to retry (takes no arguments, returns result)
+        max_attempts: Maximum number of attempts (default: 3)
+        base_delay: Base delay in seconds, doubles each retry (default: 1.0)
+        operation_name: Name for logging purposes (default: "operation")
+
+    Returns:
+        Result of successful operation
+
+    Raises:
+        Exception: Last exception if all retries fail
+
+    Example:
+        >>> def fetch_page():
+        ...     response = requests.get(url, timeout=30)
+        ...     response.raise_for_status()
+        ...     return response.text
+        >>> content = retry_with_backoff(fetch_page, max_attempts=3, operation_name=f"fetch {url}")
+    """
+    last_exception: Optional[Exception] = None
+
+    for attempt in range(1, max_attempts + 1):
+        try:
+            return operation()
+        except Exception as e:
+            last_exception = e
+            if attempt < max_attempts:
+                delay = base_delay * (2 ** (attempt - 1))
+                logger.warning(
+                    "%s failed (attempt %d/%d), retrying in %.1fs: %s",
+                    operation_name, attempt, max_attempts, delay, e
+                )
+                time.sleep(delay)
+            else:
+                logger.error(
+                    "%s failed after %d attempts: %s",
+                    operation_name, max_attempts, e
+                )
+
+    # This should always have a value, but mypy doesn't know that
+    if last_exception is not None:
+        raise last_exception
+    raise RuntimeError(f"{operation_name} failed with no exception captured")
+
+
+async def retry_with_backoff_async(
+    operation: Callable[[], Awaitable[T]],
+    max_attempts: int = 3,
+    base_delay: float = 1.0,
+    operation_name: str = "operation"
+) -> T:
+    """Async version of retry_with_backoff for async operations.
+
+    Args:
+        operation: Async function to retry (takes no arguments, returns awaitable)
+        max_attempts: Maximum number of attempts (default: 3)
+        base_delay: Base delay in seconds, doubles each retry (default: 1.0)
+        operation_name: Name for logging purposes (default: "operation")
+
+    Returns:
+        Result of successful operation
+
+    Raises:
+        Exception: Last exception if all retries fail
+
+    Example:
+        >>> async def fetch_page():
+        ...     response = await client.get(url, timeout=30.0)
+        ...     response.raise_for_status()
+        ...     return response.text
+        >>> content = await retry_with_backoff_async(fetch_page, operation_name=f"fetch {url}")
+    """
+    import asyncio
+
+    last_exception: Optional[Exception] = None
+
+    for attempt in range(1, max_attempts + 1):
+        try:
+            return await operation()
+        except Exception as e:
+            last_exception = e
+            if attempt < max_attempts:
+                delay = base_delay * (2 ** (attempt - 1))
+                logger.warning(
+                    "%s failed (attempt %d/%d), retrying in %.1fs: %s",
+                    operation_name, attempt, max_attempts, delay, e
+                )
+                await asyncio.sleep(delay)
+            else:
+                logger.error(
+                    "%s failed after %d attempts: %s",
+                    operation_name, max_attempts, e
+                )
+
+    if last_exception is not None:
+        raise last_exception
+    raise RuntimeError(f"{operation_name} failed with no exception captured")
diff --git a/tests/test_utilities.py b/tests/test_utilities.py
index 6026e7b..8f7f360 100644
--- a/tests/test_utilities.py
+++ b/tests/test_utilities.py
@@ -17,7 +17,9 @@ from skill_seekers.cli.utils import (
     format_file_size,
     validate_skill_directory,
     validate_zip_file,
-    print_upload_instructions
+    print_upload_instructions,
+    retry_with_backoff,
+    retry_with_backoff_async
 )
 
 
@@ -218,5 +220,119 @@ class TestPrintUploadInstructions(unittest.TestCase):
             self.fail(f"print_upload_instructions raised {e}")
 
 
+class TestRetryWithBackoff(unittest.TestCase):
+    """Test retry_with_backoff function"""
+
+    def test_successful_operation_first_try(self):
+        """Test operation that succeeds on first try"""
+        call_count = 0
+
+        def operation():
+            nonlocal call_count
+            call_count += 1
+            return "success"
+
+        result = retry_with_backoff(operation, max_attempts=3)
+        self.assertEqual(result, "success")
+        self.assertEqual(call_count, 1)
+
+    def test_successful_operation_after_retry(self):
+        """Test operation that fails once then succeeds"""
+        call_count = 0
+
+        def operation():
+            nonlocal call_count
+            call_count += 1
+            if call_count < 2:
+                raise ConnectionError("Temporary failure")
+            return "success"
+
+        result = retry_with_backoff(operation, max_attempts=3, base_delay=0.01)
+        self.assertEqual(result, "success")
+        self.assertEqual(call_count, 2)
+
+    def test_all_retries_fail(self):
+        """Test operation that fails all retries"""
+        call_count = 0
+
+        def operation():
+            nonlocal call_count
+            call_count += 1
+            raise ConnectionError("Persistent failure")
+
+        with self.assertRaises(ConnectionError):
+            retry_with_backoff(operation, max_attempts=3, base_delay=0.01)
+        self.assertEqual(call_count, 3)
+
+    def test_exponential_backoff_timing(self):
+        """Test that delays follow exponential pattern"""
+        import time
+
+        call_times = []
+
+        def operation():
+            call_times.append(time.time())
+            if len(call_times) < 3:
+                raise ConnectionError("Fail")
+            return "success"
+
+        retry_with_backoff(operation, max_attempts=3, base_delay=0.1)
+
+        # Check that delays are increasing (exponential)
+        # First delay: ~0.1s, Second delay: ~0.2s
+        delay1 = call_times[1] - call_times[0]
+        delay2 = call_times[2] - call_times[1]
+
+        self.assertGreater(delay1, 0.05)  # First delay at least base_delay/2
+        self.assertGreater(delay2, delay1 * 1.5)  # Second should be ~2x first
+
+
+class TestRetryWithBackoffAsync(unittest.TestCase):
+    """Test retry_with_backoff_async function"""
+
+    def test_async_successful_operation(self):
+        """Test async operation that succeeds"""
+        import asyncio
+
+        async def operation():
+            return "async success"
+
+        result = asyncio.run(
+            retry_with_backoff_async(operation, max_attempts=3)
+        )
+        self.assertEqual(result, "async success")
+
+    def test_async_retry_then_success(self):
+        """Test async operation that fails then succeeds"""
+        import asyncio
+
+        call_count = 0
+
+        async def operation():
+            nonlocal call_count
+            call_count += 1
+            if call_count < 2:
+                raise ConnectionError("Async failure")
+            return "async success"
+
+        result = asyncio.run(
+            retry_with_backoff_async(operation, max_attempts=3, base_delay=0.01)
+        )
+        self.assertEqual(result, "async success")
+        self.assertEqual(call_count, 2)
+
+    def test_async_all_retries_fail(self):
+        """Test async operation that fails all retries"""
+        import asyncio
+
+        async def operation():
+            raise ConnectionError("Persistent async failure")
+
+        with self.assertRaises(ConnectionError):
+            asyncio.run(
+                retry_with_backoff_async(operation, max_attempts=2, base_delay=0.01)
+            )
+
+
 if __name__ == '__main__':
     unittest.main()