"""
Tests for reliability module - Production hardening features.

Tests:
- Rate limiting with exponential backoff
- Retry logic with transient error detection
- Circuit breaker pattern
- Fallback model selection
- Graceful shutdown handling
"""

import pytest
import asyncio
from datetime import datetime, timedelta
from unittest.mock import Mock, AsyncMock, patch
import signal

from agent_orchestrator.reliability import (
    # Rate Limiter
    RateLimiter,
    RateLimitConfig,
    RateLimitState,
    RateLimiterRegistry,
    DEFAULT_RATE_LIMITS,
    get_rate_limiter_registry,
    with_rate_limit,
    # Retry
    RetryConfig,
    RetryManager,
    FallbackModel,
    FallbackConfig,
    FallbackRequired,
    CircuitBreaker,
    retry,
    is_transient_exception,
    is_rate_limit_exception,
    get_retry_manager,
    set_retry_manager,
    DEFAULT_FALLBACKS,
    TRANSIENT_EXCEPTIONS,
    # Shutdown
    GracefulShutdown,
    ShutdownConfig,
    ShutdownTask,
    ShutdownResult,
    ShutdownState,
    get_shutdown_manager,
    set_shutdown_manager,
    stop_agents,
    cancel_pending_tasks,
    close_connections,
    persist_state,
    flush_logs,
)


# =============================================================================
# Rate Limiter Tests
# =============================================================================

class TestRateLimitConfig:
    """Checks the field values exposed by RateLimitConfig."""

    def test_default_config(self):
        """A config built without arguments reports the expected defaults."""
        defaults = RateLimitConfig()
        expected = {
            "requests_per_minute": 60,
            "tokens_per_minute": 100000,
            "concurrent_requests": 10,
        }
        for field_name, value in expected.items():
            assert getattr(defaults, field_name) == value

    def test_custom_config(self):
        """Values passed to the constructor are reflected on the instance."""
        overrides = {
            "requests_per_minute": 30,
            "tokens_per_minute": 50000,
            "concurrent_requests": 5,
        }
        custom = RateLimitConfig(**overrides)
        for field_name, value in overrides.items():
            assert getattr(custom, field_name) == value


class TestRateLimiter:
    """Behavioural checks for RateLimiter: acquisition, bookkeeping, backoff."""

    @pytest.fixture
    def limiter(self):
        """Build a limiter with small limits so they are easy to hit."""
        small_limits = RateLimitConfig(
            requests_per_minute=10,
            tokens_per_minute=1000,
            concurrent_requests=3,
        )
        return RateLimiter(provider="test", config=small_limits)

    @pytest.mark.asyncio
    async def test_acquire_no_wait(self, limiter):
        """A single acquire below every limit reports zero wait time."""
        delay = await limiter.acquire(tokens=100)
        assert delay == 0.0

    @pytest.mark.asyncio
    async def test_acquire_tracks_requests(self, limiter):
        """Five acquires of 50 tokens show up in the limiter state."""
        remaining = 5
        while remaining:
            await limiter.acquire(tokens=50)
            remaining -= 1

        snapshot = limiter.state
        assert snapshot.current_requests == 5
        assert snapshot.current_tokens == 250

    @pytest.mark.asyncio
    async def test_concurrent_request_limit(self, limiter):
        """The acquire past the concurrency cap reports a positive wait."""
        # The fixture allows three concurrent requests; use them all up.
        for _ in range(3):
            await limiter.acquire()

        # One more than the cap must be told to wait.
        delay = await limiter.acquire()
        assert delay > 0

    @pytest.mark.asyncio
    async def test_record_success(self, limiter):
        """One recorded success is counted in the state."""
        await limiter.acquire()
        await limiter.record_success()

        assert limiter.state.success_count == 1

    @pytest.mark.asyncio
    async def test_record_failure_increases_backoff(self, limiter):
        """A rate-limit failure returns and stores a positive backoff."""
        await limiter.acquire()

        returned_backoff = await limiter.record_failure(is_rate_limit=True)
        assert returned_backoff > 0

        snapshot = limiter.state
        assert snapshot.error_count == 1
        assert snapshot.current_backoff > 0

    @pytest.mark.asyncio
    async def test_exponential_backoff(self, limiter):
        """Three consecutive rate-limit failures yield strictly growing backoffs."""
        observed = []

        while len(observed) < 3:
            await limiter.acquire()
            observed.append(await limiter.record_failure(is_rate_limit=True))

        first, second, third = observed
        # Every failure must produce a longer backoff than the previous one.
        assert second > first
        assert third > second

    @pytest.mark.asyncio
    async def test_backoff_capped_at_max(self, limiter):
        """After twenty failures the stored backoff does not exceed max_backoff."""
        for _ in range(20):
            await limiter.acquire()
            await limiter.record_failure(is_rate_limit=True)

        ceiling = limiter.config.max_backoff
        assert limiter.state.current_backoff <= ceiling

    @pytest.mark.asyncio
    async def test_success_reduces_backoff(self, limiter):
        """Successes after a run of failures lower the stored backoff."""
        # Accumulate backoff with three rate-limit failures.
        for _ in range(3):
            await limiter.acquire()
            await limiter.record_failure(is_rate_limit=True)

        backoff_before = limiter.state.current_backoff

        # Then report three successes.
        for _ in range(3):
            await limiter.record_success()

        assert limiter.state.current_backoff < backoff_before


class TestRateLimiterRegistry:
    """Lookup behaviour of RateLimiterRegistry."""

    def test_get_creates_limiter(self):
        """Requesting a provider yields a limiter labelled with that provider."""
        created = RateLimiterRegistry().get("anthropic")

        assert created is not None
        assert created.provider == "anthropic"

    def test_get_returns_same_limiter(self):
        """Two lookups of the same provider return the identical object."""
        reg = RateLimiterRegistry()
        first = reg.get("anthropic")
        second = reg.get("anthropic")

        assert first is second

    def test_default_configs(self):
        """The limiter for a known provider uses that provider's default RPM."""
        reg = RateLimiterRegistry()
        actual_rpm = reg.get("anthropic").config.requests_per_minute
        expected_rpm = DEFAULT_RATE_LIMITS["anthropic"].requests_per_minute

        assert actual_rpm == expected_rpm


# =============================================================================
# Retry Tests
# =============================================================================

class TestRetryConfig:
    """Checks the field values exposed by RetryConfig."""

    def test_default_config(self):
        """A config built without arguments reports the expected defaults."""
        defaults = RetryConfig()
        expected = {
            "max_retries": 3,
            "initial_delay": 1.0,
            "max_delay": 60.0,
            "exponential_base": 2.0,
        }
        for field_name, value in expected.items():
            assert getattr(defaults, field_name) == value

    def test_custom_config(self):
        """Values passed to the constructor are reflected on the instance."""
        overrides = {
            "max_retries": 5,
            "initial_delay": 0.5,
            "max_delay": 30.0,
        }
        custom = RetryConfig(**overrides)
        for field_name, value in overrides.items():
            assert getattr(custom, field_name) == value


class TestTransientExceptionDetection:
    """Classification checks for is_transient_exception / is_rate_limit_exception."""

    def test_rate_limit_is_transient(self):
        """An exception type named RateLimitError counts as transient."""
        class RateLimitError(Exception):
            pass

        error = RateLimitError()
        assert is_transient_exception(error)

    def test_connection_error_is_transient(self):
        """The builtin ConnectionError counts as transient."""
        error = ConnectionError("Connection refused")
        assert is_transient_exception(error)

    def test_timeout_is_transient(self):
        """The builtin TimeoutError counts as transient."""
        error = TimeoutError("Request timed out")
        assert is_transient_exception(error)

    def test_transient_message_patterns(self):
        """Certain message texts mark an otherwise unknown error as transient."""
        class CustomError(Exception):
            pass

        transient_messages = (
            "rate limit exceeded",
            "too many requests",
            "service unavailable",
            "temporarily unavailable",
        )
        for text in transient_messages:
            assert is_transient_exception(CustomError(text))

    def test_non_transient_error(self):
        """An unrelated error type and message is not reported as transient."""
        class ValidationError(Exception):
            pass

        error = ValidationError("Invalid input")
        assert not is_transient_exception(error)

    def test_is_rate_limit_exception(self):
        """Rate-limit detection accepts the listed cases and rejects a generic one."""
        class RateLimitError(Exception):
            pass

        assert is_rate_limit_exception(RateLimitError())
        for text in ("429 Too Many Requests", "quota exceeded"):
            assert is_rate_limit_exception(Exception(text))
        assert not is_rate_limit_exception(Exception("internal error"))


class TestRetryDecorator:
    """Behaviour of the @retry decorator on coroutine functions."""

    @pytest.mark.asyncio
    async def test_successful_call_no_retry(self):
        """A coroutine that succeeds is invoked exactly once."""
        invocations = []

        @retry(max_retries=3)
        async def successful_func():
            invocations.append(None)
            return "success"

        outcome = await successful_func()

        assert outcome == "success"
        assert len(invocations) == 1

    @pytest.mark.asyncio
    async def test_retry_on_transient_error(self):
        """Two ConnectionErrors followed by success take three invocations."""
        invocations = []

        @retry(max_retries=3, initial_delay=0.01)
        async def flaky_func():
            invocations.append(None)
            # Fail on the first and second invocation only.
            if len(invocations) < 3:
                raise ConnectionError("Connection refused")
            return "success"

        outcome = await flaky_func()

        assert outcome == "success"
        assert len(invocations) == 3

    @pytest.mark.asyncio
    async def test_max_retries_exceeded(self):
        """With max_retries=2 the error surfaces after three invocations."""
        invocations = []

        @retry(max_retries=2, initial_delay=0.01)
        async def always_fails():
            invocations.append(None)
            raise ConnectionError("Always fails")

        with pytest.raises(ConnectionError):
            await always_fails()

        # One initial attempt plus two retries.
        assert len(invocations) == 3

    @pytest.mark.asyncio
    async def test_no_retry_on_fatal_error(self):
        """An error not recognised as transient propagates after one invocation."""
        invocations = []

        class AuthError(Exception):
            pass

        @retry(max_retries=3, initial_delay=0.01)
        async def auth_error_func():
            invocations.append(None)
            raise AuthError("Invalid credentials")

        with pytest.raises(AuthError):
            await auth_error_func()

        # The call must not have been repeated.
        assert len(invocations) == 1

    @pytest.mark.asyncio
    async def test_retry_callback(self):
        """on_retry receives the exception and a 1-based attempt number."""
        seen = []

        def on_retry(exc, attempt):
            seen.append((type(exc).__name__, attempt))

        @retry(max_retries=2, initial_delay=0.01, on_retry=on_retry)
        async def flaky_func():
            # Keep failing until the callback has fired twice.
            if len(seen) >= 2:
                return "success"
            raise ConnectionError("Temporary failure")

        await flaky_func()

        assert len(seen) == 2
        first_call, second_call = seen[0], seen[1]
        assert first_call == ("ConnectionError", 1)
        assert second_call == ("ConnectionError", 2)


class TestCircuitBreaker:
    """State-transition checks for CircuitBreaker."""

    @pytest.fixture
    def breaker(self):
        """Breaker that trips after 3 failures and uses a 0.1 s recovery timeout."""
        return CircuitBreaker(
            failure_threshold=3,
            recovery_timeout=0.1,
            half_open_requests=2,
        )

    @staticmethod
    async def _trip(breaker):
        """Record three failures, matching the fixture's failure_threshold."""
        for _ in range(3):
            await breaker.record_failure()

    @pytest.mark.asyncio
    async def test_initial_state_closed(self, breaker):
        """A fresh breaker is CLOSED and permits execution."""
        assert breaker.state == CircuitBreaker.State.CLOSED
        allowed = await breaker.can_execute()
        assert allowed

    @pytest.mark.asyncio
    async def test_opens_after_failures(self, breaker):
        """Reaching the failure threshold opens the circuit and blocks execution."""
        await self._trip(breaker)

        assert breaker.state == CircuitBreaker.State.OPEN
        allowed = await breaker.can_execute()
        assert not allowed

    @pytest.mark.asyncio
    async def test_half_open_after_timeout(self, breaker):
        """After the recovery timeout, can_execute moves the breaker to HALF_OPEN."""
        await self._trip(breaker)

        # Sleep a little longer than the 0.1 s recovery timeout.
        await asyncio.sleep(0.15)

        allowed = await breaker.can_execute()
        assert allowed
        assert breaker.state == CircuitBreaker.State.HALF_OPEN

    @pytest.mark.asyncio
    async def test_closes_after_half_open_success(self, breaker):
        """Two successful half-open probes close the circuit again."""
        await self._trip(breaker)

        # Sleep a little longer than the 0.1 s recovery timeout.
        await asyncio.sleep(0.15)

        # The fixture sets half_open_requests=2, so probe twice.
        for _ in range(2):
            await breaker.can_execute()
            await breaker.record_success()

        assert breaker.state == CircuitBreaker.State.CLOSED

    @pytest.mark.asyncio
    async def test_reopens_on_half_open_failure(self, breaker):
        """A failure recorded while half-open sends the breaker back to OPEN."""
        await self._trip(breaker)

        # Sleep a little longer than the 0.1 s recovery timeout.
        await asyncio.sleep(0.15)

        # This call moves the breaker into half-open.
        await breaker.can_execute()

        # A failing probe must reopen the circuit.
        await breaker.record_failure()

        assert breaker.state == CircuitBreaker.State.OPEN

    def test_get_stats(self, breaker):
        """get_stats exposes state/failures/successes; a fresh breaker is CLOSED."""
        snapshot = breaker.get_stats()

        for key in ("state", "failures", "successes"):
            assert key in snapshot
        assert snapshot["state"] == CircuitBreaker.State.CLOSED


class TestFallbackModel:
    """Checks for FallbackModel construction and the DEFAULT_FALLBACKS table."""

    def test_fallback_model_creation(self):
        """Constructor arguments are readable back as attributes."""
        candidate = FallbackModel(
            provider="anthropic",
            model="claude-3-haiku",
            priority=1,
            on_rate_limit=True,
        )

        assert candidate.provider == "anthropic"
        assert candidate.model == "claude-3-haiku"
        assert candidate.priority == 1
        assert candidate.on_rate_limit is True

    def test_default_fallbacks_exist(self):
        """The default table has entries for the three listed primary models."""
        for primary in ("claude-3-opus", "claude-3-sonnet", "gpt-4o"):
            assert primary in DEFAULT_FALLBACKS


class TestRetryManager:
    """Checks for RetryManager: breakers, fallback lookup and execute_with_retry."""

    @pytest.fixture
    def manager(self):
        """A RetryManager built with its default arguments."""
        return RetryManager()

    def test_get_circuit_breaker(self, manager):
        """A provider lookup yields a CircuitBreaker instance."""
        found = manager.get_circuit_breaker("anthropic")

        assert found is not None
        assert isinstance(found, CircuitBreaker)

    def test_get_same_circuit_breaker(self, manager):
        """Repeated lookups for one provider return the identical breaker."""
        first = manager.get_circuit_breaker("anthropic")
        second = manager.get_circuit_breaker("anthropic")

        assert first is second

    def test_get_fallback_model_rate_limit(self, manager):
        """A rate-limit error on claude-3-opus falls back to claude-3-sonnet."""
        class RateLimitError(Exception):
            pass

        chosen = manager.get_fallback_model("claude-3-opus", RateLimitError())

        assert chosen is not None
        assert chosen.model == "claude-3-sonnet"

    def test_get_fallback_model_none_for_unknown(self, manager):
        """A model with no configured fallback yields None."""
        chosen = manager.get_fallback_model("unknown-model", Exception())

        assert chosen is None

    @pytest.mark.asyncio
    async def test_execute_with_retry_success(self, manager):
        """The wrapped coroutine's return value is passed through."""
        async def successful_func():
            return "result"

        outcome = await manager.execute_with_retry(
            successful_func,
            provider="test",
        )

        assert outcome == "result"

    @pytest.mark.asyncio
    async def test_execute_with_retry_retries(self, manager):
        """One ConnectionError followed by success takes two invocations."""
        manager.retry_config = RetryConfig(
            max_retries=2,
            initial_delay=0.01,
        )

        invocations = []

        async def flaky_func():
            invocations.append(None)
            # Only the very first invocation fails.
            if len(invocations) < 2:
                raise ConnectionError("Temporary")
            return "success"

        outcome = await manager.execute_with_retry(
            flaky_func,
            provider="test",
        )

        assert outcome == "success"
        assert len(invocations) == 2

    @pytest.mark.asyncio
    async def test_fallback_required_exception(self, manager):
        """With retries disabled, a rate-limit error raises FallbackRequired."""
        # max_retries=0 means the first failure exhausts the retry budget.
        manager.retry_config = RetryConfig(
            max_retries=0,
            initial_delay=0.01,
        )

        async def rate_limited():
            raise Exception("rate limit exceeded")

        with pytest.raises(FallbackRequired) as exc_info:
            await manager.execute_with_retry(
                rate_limited,
                provider="anthropic",
                model="claude-3-opus",
            )

        suggested = exc_info.value.fallback
        assert suggested.model == "claude-3-sonnet"


# =============================================================================
# Graceful Shutdown Tests
# =============================================================================

class TestShutdownConfig:
    """Checks the field values exposed by ShutdownConfig."""

    def test_default_config(self):
        """A config built without arguments reports the expected defaults."""
        defaults = ShutdownConfig()
        assert defaults.grace_period_seconds == 30.0
        for flag_name in ("force_after_grace_period", "persist_state"):
            assert getattr(defaults, flag_name) is True

    def test_custom_config(self):
        """Values passed to the constructor are reflected on the instance."""
        custom = ShutdownConfig(
            grace_period_seconds=60.0,
            force_after_grace_period=False,
        )
        assert custom.grace_period_seconds == 60.0
        assert custom.force_after_grace_period is False


class TestShutdownTask:
    """Checks for ShutdownTask construction."""

    def test_task_creation(self):
        """Name, priority and required flag are readable back as attributes."""
        created = ShutdownTask(
            name="stop_agents",
            callback=lambda: None,
            priority=1,
            timeout_seconds=10.0,
            required=True,
        )

        assert created.name == "stop_agents"
        assert created.priority == 1
        assert created.required is True


class TestGracefulShutdown:
    """Tests for GracefulShutdown: task execution, in-flight work and callbacks."""

    @pytest.fixture
    def shutdown_manager(self):
        """Create a shutdown manager with a short grace period for fast tests."""
        config = ShutdownConfig(
            grace_period_seconds=1.0,
            check_interval_seconds=0.1,
        )
        return GracefulShutdown(config=config)

    def test_initial_state_running(self, shutdown_manager):
        """A fresh manager is RUNNING and not shutting down."""
        assert shutdown_manager.state == ShutdownState.RUNNING
        assert not shutdown_manager.is_shutting_down

    def test_register_task(self, shutdown_manager):
        """A registered task is stored under the given name."""
        shutdown_manager.register_task(
            name="test_task",
            callback=lambda: None,
            priority=5,
        )

        # NOTE(review): this inspects a private attribute of the manager.
        assert len(shutdown_manager._shutdown_tasks) == 1
        assert shutdown_manager._shutdown_tasks[0].name == "test_task"

    def test_track_work(self, shutdown_manager):
        """Each tracked work id increases the in-flight count."""
        shutdown_manager.track_work("work-1")
        shutdown_manager.track_work("work-2")

        assert shutdown_manager.get_in_flight_count() == 2

    def test_complete_work(self, shutdown_manager):
        """Completing tracked work removes it from the in-flight count."""
        shutdown_manager.track_work("work-1")
        shutdown_manager.complete_work("work-1")

        assert shutdown_manager.get_in_flight_count() == 0

    @pytest.mark.asyncio
    async def test_shutdown_runs_tasks(self, shutdown_manager):
        """Shutdown runs registered tasks, lower priority number first."""
        results = []

        async def task1():
            results.append("task1")

        async def task2():
            results.append("task2")

        shutdown_manager.register_task("task1", task1, priority=1)
        shutdown_manager.register_task("task2", task2, priority=2)

        result = await shutdown_manager.shutdown()

        assert result.success
        assert result.tasks_completed == 2
        assert results == ["task1", "task2"]  # Priority order

    @pytest.mark.asyncio
    async def test_shutdown_waits_for_in_flight(self, shutdown_manager):
        """Shutdown waits for tracked work that completes within the grace period."""
        shutdown_manager.track_work("work-1")

        async def complete_work():
            await asyncio.sleep(0.2)
            shutdown_manager.complete_work("work-1")

        # Keep a strong reference: the event loop only holds a weak one, so an
        # unreferenced task may be garbage-collected before it finishes.
        background = asyncio.create_task(complete_work())
        try:
            result = await shutdown_manager.shutdown()

            # Assert before awaiting the background task, so the test still
            # proves that shutdown() itself waited for the work.
            assert result.success
            assert shutdown_manager.get_in_flight_count() == 0
        finally:
            # Do not leave a pending task behind if an assertion failed;
            # cancelling an already finished task is a no-op.
            background.cancel()
            await asyncio.gather(background, return_exceptions=True)

    @pytest.mark.asyncio
    async def test_shutdown_timeout_forces_completion(self, shutdown_manager):
        """Work that never completes makes shutdown report was_forced."""
        shutdown_manager.config.grace_period_seconds = 0.1
        shutdown_manager.track_work("stuck-work")

        result = await shutdown_manager.shutdown()

        assert result.was_forced
        assert shutdown_manager.get_in_flight_count() == 1  # Still tracked

    @pytest.mark.asyncio
    async def test_shutdown_handles_task_failure(self, shutdown_manager):
        """A failing non-required task is skipped and reported in errors."""
        async def failing_task():
            raise RuntimeError("Task failed")

        shutdown_manager.register_task(
            "failing",
            failing_task,
            required=False,
        )

        result = await shutdown_manager.shutdown()

        assert result.success  # Non-required task doesn't block
        assert result.tasks_skipped == 1
        assert "failing" in result.errors[0]

    @pytest.mark.asyncio
    async def test_required_task_failure_blocks_success(self, shutdown_manager):
        """A failing required task marks the shutdown as unsuccessful."""
        async def required_failing_task():
            raise RuntimeError("Critical failure")

        shutdown_manager.register_task(
            "critical",
            required_failing_task,
            required=True,
        )

        result = await shutdown_manager.shutdown()

        assert not result.success
        assert result.tasks_failed == 1

    @pytest.mark.asyncio
    async def test_shutdown_prevents_double_shutdown(self, shutdown_manager):
        """A second shutdown() call fails with an 'already in progress' error."""
        result1 = await shutdown_manager.shutdown()
        result2 = await shutdown_manager.shutdown()

        assert result1.success
        assert not result2.success
        assert "already in progress" in result2.errors[0]

    @pytest.mark.asyncio
    async def test_pre_shutdown_callbacks(self, shutdown_manager):
        """Callbacks registered with on_pre_shutdown are invoked by shutdown()."""
        called = []

        def pre_shutdown():
            called.append("pre")

        shutdown_manager.on_pre_shutdown(pre_shutdown)

        await shutdown_manager.shutdown()

        assert "pre" in called

    @pytest.mark.asyncio
    async def test_post_shutdown_callbacks(self, shutdown_manager):
        """Callbacks registered with on_post_shutdown are invoked by shutdown()."""
        called = []

        def post_shutdown():
            called.append("post")

        shutdown_manager.on_post_shutdown(post_shutdown)

        await shutdown_manager.shutdown()

        assert "post" in called

    def test_get_status(self, shutdown_manager):
        """A fresh manager reports RUNNING with no work and no tasks."""
        status = shutdown_manager.get_status()

        assert status["state"] == ShutdownState.RUNNING
        assert status["in_flight_count"] == 0
        assert status["registered_tasks"] == 0

    @pytest.mark.asyncio
    async def test_task_timeout(self, shutdown_manager):
        """A task exceeding its timeout is recorded as 'timeout' and skipped."""
        async def slow_task():
            # Far longer than the 0.1 s timeout configured below.
            await asyncio.sleep(10.0)

        shutdown_manager.register_task(
            "slow",
            slow_task,
            timeout_seconds=0.1,
            required=False,
        )

        result = await shutdown_manager.shutdown()

        assert result.task_results["slow"] == "timeout"
        assert result.tasks_skipped == 1


class TestShutdownHelpers:
    """Checks that each shutdown helper calls the expected method on its argument."""

    @pytest.mark.asyncio
    async def test_stop_agents(self):
        """stop_agents calls stop() once on the single agent in the mapping."""
        fake_agent = AsyncMock()
        fake_agent.stop = AsyncMock()

        await stop_agents({"agent-1": fake_agent})

        fake_agent.stop.assert_called_once()

    @pytest.mark.asyncio
    async def test_cancel_pending_tasks(self):
        """cancel_pending_tasks calls cancel_all() once on the queue."""
        fake_queue = AsyncMock()
        fake_queue.cancel_all = AsyncMock()

        await cancel_pending_tasks(fake_queue)

        fake_queue.cancel_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_close_connections(self):
        """close_connections calls close() once on the single connection."""
        fake_connection = AsyncMock()
        fake_connection.close = AsyncMock()

        await close_connections([fake_connection])

        fake_connection.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_persist_state(self):
        """persist_state passes the state dict to db.save_shutdown_state()."""
        fake_db = AsyncMock()
        fake_db.save_shutdown_state = AsyncMock()
        payload = {"key": "value"}

        await persist_state(fake_db, payload)

        fake_db.save_shutdown_state.assert_called_once_with(payload)


class TestSingletons:
    """Tests for the module-level singleton accessors.

    The setters replace process-wide objects, so an autouse fixture snapshots
    both singletons before each test and restores them afterwards. Without
    it, the custom instances installed here would leak into later tests.
    """

    @pytest.fixture(autouse=True)
    def _restore_singletons(self):
        """Save the global managers before a test and restore them after it."""
        original_retry = get_retry_manager()
        original_shutdown = get_shutdown_manager()
        yield
        set_retry_manager(original_retry)
        set_shutdown_manager(original_shutdown)

    def test_get_retry_manager(self):
        """get_retry_manager returns a RetryManager instance."""
        manager = get_retry_manager()
        assert isinstance(manager, RetryManager)

    def test_set_retry_manager(self):
        """An instance passed to set_retry_manager is returned by the getter."""
        custom = RetryManager()
        set_retry_manager(custom)

        assert get_retry_manager() is custom

    def test_get_shutdown_manager(self):
        """get_shutdown_manager returns a GracefulShutdown instance."""
        manager = get_shutdown_manager()
        assert isinstance(manager, GracefulShutdown)

    def test_set_shutdown_manager(self):
        """An instance passed to set_shutdown_manager is returned by the getter."""
        custom = GracefulShutdown()
        set_shutdown_manager(custom)

        assert get_shutdown_manager() is custom


# =============================================================================
# Integration Tests
# =============================================================================

class TestRateLimitRetryIntegration:
    """Exercises a rate limiter and the retry manager in one call path."""

    @pytest.mark.asyncio
    async def test_rate_limit_with_retry(self):
        """A call that is rate limited once succeeds on the second attempt."""
        limiters = RateLimiterRegistry()
        retrier = RetryManager()
        invocations = []

        async def api_call():
            invocations.append(None)

            provider_limiter = limiters.get("test")
            await provider_limiter.acquire(tokens=100)

            # Only the very first invocation simulates a rate-limit rejection.
            if len(invocations) < 2:
                await provider_limiter.record_failure(is_rate_limit=True)
                raise Exception("rate limit exceeded")

            await provider_limiter.record_success()
            return "success"

        retrier.retry_config = RetryConfig(
            max_retries=3,
            initial_delay=0.01,
        )

        outcome = await retrier.execute_with_retry(
            api_call,
            provider="test",
        )

        assert outcome == "success"
        assert len(invocations) == 2


class TestShutdownIntegration:
    """Runs GracefulShutdown together with other reliability components."""

    @pytest.mark.asyncio
    async def test_shutdown_with_rate_limiter(self):
        """A cleanup task registered next to live limiters runs on shutdown."""
        coordinator = GracefulShutdown()
        limiters = RateLimiterRegistry()

        # Instantiate two limiters so the registry is not empty.
        for provider in ("anthropic", "openai"):
            limiters.get(provider)

        invoked = []

        async def cleanup_limiters():
            invoked.append(True)

        coordinator.register_task("cleanup_limiters", cleanup_limiters)

        await coordinator.shutdown()

        assert invoked

    @pytest.mark.asyncio
    async def test_full_shutdown_sequence(self):
        """Five tasks with priorities 1..5 run in ascending priority order."""
        coordinator = GracefulShutdown()
        order = []

        async def stop_agents_task():
            order.append("agents")

        async def cancel_tasks_task():
            order.append("tasks")

        async def close_connections_task():
            order.append("connections")

        async def persist_state_task():
            order.append("state")

        async def flush_logs_task():
            order.append("logs")

        # (registered name, callback) pairs listed in ascending priority.
        plan = [
            ("agents", stop_agents_task),
            ("tasks", cancel_tasks_task),
            ("connections", close_connections_task),
            ("state", persist_state_task),
            ("logging", flush_logs_task),
        ]
        for rank, (task_name, task_fn) in enumerate(plan, start=1):
            coordinator.register_task(task_name, task_fn, priority=rank)

        outcome = await coordinator.shutdown()

        assert outcome.success
        assert order == ["agents", "tasks", "connections", "state", "logs"]
