feat(utils): add retry utilities with exponential backoff (#208)
Add retry_with_backoff() and retry_with_backoff_async() for network operations.

Features:
- Configurable max attempts (default: 3)
- Exponential backoff with configurable base delay
- Operation name for meaningful log messages
- Both sync and async versions

Addresses E2.6: Add retry logic for network failures

Co-authored-by: Joseph Magly <1159087+jmagly@users.noreply.github.com>
This commit is contained in:
@@ -7,8 +7,14 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
import platform
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Dict, Union
|
||||
from typing import Optional, Tuple, Dict, Union, TypeVar, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def open_folder(folder_path: Union[str, Path]) -> bool:
|
||||
@@ -225,3 +231,113 @@ def read_reference_files(skill_dir: Union[str, Path], max_chars: int = 100000, p
|
||||
break
|
||||
|
||||
return references
|
||||
|
||||
|
||||
def retry_with_backoff(
    operation: Callable[[], T],
    max_attempts: int = 3,
    base_delay: float = 1.0,
    operation_name: str = "operation"
) -> T:
    """Run *operation*, retrying on failure with exponential backoff.

    Intended for network calls that may fail transiently. The wait before
    retry N is ``base_delay * 2**(N-1)`` seconds, so delays double after
    each failed attempt.

    Args:
        operation: Zero-argument callable producing the result.
        max_attempts: Total number of tries before giving up (default: 3)
        base_delay: Initial wait in seconds, doubled after each failure
            (default: 1.0)
        operation_name: Human-readable label used in log messages
            (default: "operation")

    Returns:
        Whatever *operation* returns on its first successful call.

    Raises:
        Exception: The exception from the final failed attempt.

    Example:
        >>> def fetch_page():
        ...     response = requests.get(url, timeout=30)
        ...     response.raise_for_status()
        ...     return response.text
        >>> content = retry_with_backoff(fetch_page, max_attempts=3, operation_name=f"fetch {url}")
    """
    last_exception: Optional[Exception] = None

    for attempt_no in range(1, max_attempts + 1):
        try:
            return operation()
        except Exception as exc:
            last_exception = exc
            if attempt_no == max_attempts:
                # Out of retries: record the final failure and fall through.
                logger.error(
                    "%s failed after %d attempts: %s",
                    operation_name, max_attempts, exc
                )
                continue
            wait = base_delay * (2 ** (attempt_no - 1))
            logger.warning(
                "%s failed (attempt %d/%d), retrying in %.1fs: %s",
                operation_name, attempt_no, max_attempts, wait, exc
            )
            time.sleep(wait)

    # Reached only after every attempt raised; the fallback RuntimeError is
    # unreachable in practice but keeps type checkers satisfied.
    if last_exception is not None:
        raise last_exception
    raise RuntimeError(f"{operation_name} failed with no exception captured")
|
||||
|
||||
|
||||
async def retry_with_backoff_async(
    operation: Callable[[], T],
    max_attempts: int = 3,
    base_delay: float = 1.0,
    operation_name: str = "operation"
) -> T:
    """Async version of retry_with_backoff for async operations.

    Awaits *operation* and retries on any exception, sleeping (without
    blocking the event loop) for ``base_delay * 2**(N-1)`` seconds before
    retry N.

    Args:
        operation: Async function to retry (takes no arguments, returns awaitable)
        max_attempts: Maximum number of attempts (default: 3)
        base_delay: Base delay in seconds, doubles each retry (default: 1.0)
        operation_name: Name for logging purposes (default: "operation")

    Returns:
        Result of successful operation

    Raises:
        Exception: Last exception if all retries fail

    Example:
        >>> async def fetch_page():
        ...     response = await client.get(url, timeout=30.0)
        ...     response.raise_for_status()
        ...     return response.text
        >>> content = await retry_with_backoff_async(fetch_page, operation_name=f"fetch {url}")
    """
    # Imported locally so modules using only the sync helper don't pay for it.
    import asyncio

    failure: Optional[Exception] = None

    for try_no in range(1, max_attempts + 1):
        try:
            return await operation()
        except Exception as exc:
            failure = exc
            if try_no == max_attempts:
                # Final attempt exhausted; log and exit the loop.
                logger.error(
                    "%s failed after %d attempts: %s",
                    operation_name, max_attempts, exc
                )
                continue
            pause = base_delay * (2 ** (try_no - 1))
            logger.warning(
                "%s failed (attempt %d/%d), retrying in %.1fs: %s",
                operation_name, try_no, max_attempts, pause, exc
            )
            await asyncio.sleep(pause)

    # Every attempt raised; re-raise the last exception. The RuntimeError
    # below is unreachable in practice but satisfies static analysis.
    if failure is not None:
        raise failure
    raise RuntimeError(f"{operation_name} failed with no exception captured")
|
||||
|
||||
@@ -17,7 +17,9 @@ from skill_seekers.cli.utils import (
|
||||
format_file_size,
|
||||
validate_skill_directory,
|
||||
validate_zip_file,
|
||||
print_upload_instructions
|
||||
print_upload_instructions,
|
||||
retry_with_backoff,
|
||||
retry_with_backoff_async
|
||||
)
|
||||
|
||||
|
||||
@@ -218,5 +220,119 @@ class TestPrintUploadInstructions(unittest.TestCase):
|
||||
self.fail(f"print_upload_instructions raised {e}")
|
||||
|
||||
|
||||
class TestRetryWithBackoff(unittest.TestCase):
    """Test retry_with_backoff function"""

    def test_successful_operation_first_try(self):
        """Test operation that succeeds on first try"""
        call_count = 0

        def operation():
            nonlocal call_count
            call_count += 1
            return "success"

        result = retry_with_backoff(operation, max_attempts=3)
        self.assertEqual(result, "success")
        self.assertEqual(call_count, 1)

    def test_successful_operation_after_retry(self):
        """Test operation that fails once then succeeds"""
        call_count = 0

        def operation():
            nonlocal call_count
            call_count += 1
            if call_count < 2:
                raise ConnectionError("Temporary failure")
            return "success"

        result = retry_with_backoff(operation, max_attempts=3, base_delay=0.01)
        self.assertEqual(result, "success")
        self.assertEqual(call_count, 2)

    def test_all_retries_fail(self):
        """Test operation that fails all retries"""
        call_count = 0

        def operation():
            nonlocal call_count
            call_count += 1
            raise ConnectionError("Persistent failure")

        with self.assertRaises(ConnectionError):
            retry_with_backoff(operation, max_attempts=3, base_delay=0.01)
        self.assertEqual(call_count, 3)

    def test_exponential_backoff_timing(self):
        """Test that requested sleep delays follow the exponential pattern"""
        # Patch time.sleep rather than measuring wall-clock gaps: the old
        # timing-based assertions were slow (~0.3s of real sleeping) and
        # flaky under load (an inflated first sleep could break the
        # "second delay > 1.5x first delay" check). Asserting on the
        # delays actually passed to time.sleep is deterministic & instant.
        from unittest.mock import patch

        call_count = 0

        def operation():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ConnectionError("Fail")
            return "success"

        with patch("time.sleep") as mock_sleep:
            result = retry_with_backoff(operation, max_attempts=3, base_delay=0.1)

        self.assertEqual(result, "success")
        delays = [call.args[0] for call in mock_sleep.call_args_list]
        self.assertEqual(len(delays), 2)
        # base_delay, then base_delay * 2 — exponential doubling.
        self.assertAlmostEqual(delays[0], 0.1)
        self.assertAlmostEqual(delays[1], 0.2)
|
||||
|
||||
|
||||
class TestRetryWithBackoffAsync(unittest.TestCase):
    """Test retry_with_backoff_async function"""

    def test_async_successful_operation(self):
        """Test async operation that succeeds"""
        import asyncio

        async def task():
            return "async success"

        outcome = asyncio.run(retry_with_backoff_async(task, max_attempts=3))
        self.assertEqual(outcome, "async success")

    def test_async_retry_then_success(self):
        """Test async operation that fails then succeeds"""
        import asyncio

        attempts = []

        async def task():
            attempts.append(1)
            if len(attempts) < 2:
                raise ConnectionError("Async failure")
            return "async success"

        outcome = asyncio.run(
            retry_with_backoff_async(task, max_attempts=3, base_delay=0.01)
        )
        self.assertEqual(outcome, "async success")
        self.assertEqual(len(attempts), 2)

    def test_async_all_retries_fail(self):
        """Test async operation that fails all retries"""
        import asyncio

        async def task():
            raise ConnectionError("Persistent async failure")

        with self.assertRaises(ConnectionError):
            asyncio.run(
                retry_with_backoff_async(task, max_attempts=2, base_delay=0.01)
            )
|
||||
|
||||
|
||||
# Allow running this test module directly (python test_utils.py) in
# addition to discovery via pytest/unittest.
if __name__ == '__main__':
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user