"""
Rate Limiter - Request rate limiting with exponential backoff.

Implements:
- Token bucket rate limiting
- Exponential backoff for retries
- Per-provider rate limits
- Sliding window tracking
"""

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Callable
import asyncio
import time
import logging
import random

logger = logging.getLogger(__name__)


@dataclass
class RateLimitConfig:
    """Configuration for rate limiting.

    ``max_backoff`` is an alias of ``max_backoff_seconds``. After
    construction both attributes always hold the same value:

    - if only ``max_backoff`` is changed from its default, it is copied
      into ``max_backoff_seconds``;
    - otherwise ``max_backoff_seconds`` wins and is copied into
      ``max_backoff``.

    The two names are only reconciled at construction time; assigning to
    one attribute afterwards does not update the other. Because the rule
    compares against the default, rebuilding a config whose alias is
    non-default with ``max_backoff_seconds`` equal to the default (for
    example via ``dataclasses.replace``) keeps the alias value; pass both
    names in that case.
    """

    # Shared default for max_backoff_seconds / max_backoff. Not annotated,
    # so the dataclass machinery does not treat it as a field.
    _DEFAULT_MAX_BACKOFF = 60.0

    # Requests per minute
    requests_per_minute: int = 60

    # Requests per second (for burst control)
    requests_per_second: int = 10

    # Maximum concurrent requests
    concurrent_requests: int = 10

    # Token limits (for LLM APIs)
    tokens_per_minute: int = 100_000
    tokens_per_day: int = 10_000_000

    # Backoff settings
    initial_backoff_seconds: float = 1.0
    max_backoff_seconds: float = _DEFAULT_MAX_BACKOFF
    max_backoff: float = _DEFAULT_MAX_BACKOFF  # Alias for tests
    backoff_multiplier: float = 2.0
    max_retries: int = 5

    # Jitter to prevent thundering herd
    jitter_factor: float = 0.1

    def __post_init__(self):
        """Reconcile ``max_backoff`` with ``max_backoff_seconds``."""
        default = self._DEFAULT_MAX_BACKOFF
        alias_only = (
            self.max_backoff != default
            and self.max_backoff_seconds == default
        )
        if alias_only:
            # Caller configured the cap through the alias; honour it
            # instead of silently discarding the value.
            self.max_backoff_seconds = self.max_backoff
        else:
            self.max_backoff = self.max_backoff_seconds


# Built-in limits keyed by provider name. The "default" entry is the
# fallback used for any provider that has no entry of its own.
DEFAULT_RATE_LIMITS: Dict[str, RateLimitConfig] = {
    "anthropic": RateLimitConfig(
        tokens_per_day=10_000_000,
        tokens_per_minute=100_000,
        requests_per_minute=60,
    ),
    "openai": RateLimitConfig(
        tokens_per_day=10_000_000,
        tokens_per_minute=90_000,
        requests_per_minute=60,
    ),
    "google": RateLimitConfig(
        tokens_per_day=50_000_000,
        tokens_per_minute=120_000,  # Gemini has high limits
        requests_per_minute=60,
    ),
    # Daily token budget is left at the RateLimitConfig default here.
    "default": RateLimitConfig(
        tokens_per_minute=50_000,
        requests_per_minute=30,
    ),
}


def _start_of_today() -> datetime:
    """Return local midnight (00:00:00.000000) of the current day."""
    return datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)


@dataclass
class RateLimitState:
    """Mutable bookkeeping for one rate limiter."""

    # Sliding-window request/token accounting
    request_timestamps: list = field(default_factory=list)
    tokens_this_minute: int = 0
    tokens_today: int = 0

    # Counters exposed mainly so tests can inspect them
    current_requests: int = 0
    current_tokens: int = 0
    success_count: int = 0
    error_count: int = 0

    # Exponential backoff bookkeeping
    consecutive_failures: int = 0
    last_failure_time: Optional[datetime] = None
    current_backoff: float = 0.0

    # Midnight of the day the daily counters belong to
    day_start: datetime = field(default_factory=_start_of_today)


class RateLimiter:
    """
    Rate limiter with exponential backoff.

    Tracks, per instance: request timestamps and token usage over the last
    minute, a per-day token total, a count of in-flight requests, and a
    backoff delay that grows on failures and shrinks on successes.

    The limiter is advisory: ``acquire`` only returns a suggested wait
    time and never sleeps on the caller's behalf.
    """

    def __init__(
        self,
        config: Optional[RateLimitConfig] = None,
        provider: str = "default",
    ):
        """
        Initialize rate limiter.

        Args:
            config: Custom rate limit configuration
            provider: Provider name for default config lookup
        """
        self.config = config or DEFAULT_RATE_LIMITS.get(
            provider, DEFAULT_RATE_LIMITS["default"]
        )
        self.provider = provider
        self._state = RateLimitState()
        self._lock = asyncio.Lock()
        # (timestamp, tokens) for each recorded request that used tokens;
        # lets the per-minute token counter expire usage by age.
        self._token_events: list = []

    @property
    def state(self) -> RateLimitState:
        """Expose state for testing and inspection."""
        return self._state

    def _clean_old_requests(self) -> None:
        """Remove request timestamps older than 1 minute."""
        cutoff = datetime.now() - timedelta(minutes=1)
        self._state.request_timestamps = [
            ts for ts in self._state.request_timestamps
            if ts > cutoff
        ]

    def _expire_old_tokens(self) -> None:
        """Subtract token usage recorded more than 1 minute ago.

        Only usage recorded through ``record_request`` is expired; a
        counter value written directly into the state is left untouched.
        """
        if not self._token_events:
            return

        cutoff = datetime.now() - timedelta(minutes=1)
        expired = 0
        live = []
        for ts, count in self._token_events:
            if ts > cutoff:
                live.append((ts, count))
            else:
                expired += count

        if len(live) != len(self._token_events):
            self._token_events = live
            # Clamp at zero in case the counter was modified externally.
            self._state.tokens_this_minute = max(
                0, self._state.tokens_this_minute - expired
            )

    def _check_daily_reset(self) -> None:
        """Reset daily counters if needed."""
        today_start = datetime.now().replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        if today_start > self._state.day_start:
            self._state.day_start = today_start
            self._state.tokens_today = 0
            logger.info(f"Daily rate limit counters reset for {self.provider}")

    async def acquire(self, tokens: int = 0) -> float:
        """
        Acquire permission to make a request.

        Every call increments the in-flight request counter (released by
        ``record_success`` / ``record_failure``) and the running
        ``current_tokens`` total, regardless of the returned wait time.

        Args:
            tokens: Expected tokens for the request

        Returns:
            Wait time in seconds before request can proceed; the largest
            of the waits implied by each limit and the current backoff.
        """
        async with self._lock:
            self._clean_old_requests()
            self._expire_old_tokens()
            self._check_daily_reset()

            now = datetime.now()
            timestamps = self._state.request_timestamps
            wait_time = 0.0

            # Check concurrent request limit
            if self._state.current_requests >= self.config.concurrent_requests:
                wait_time = max(wait_time, 0.1)  # Small wait for concurrent limit

            # Check requests per minute: wait until the oldest request in
            # the window is one minute old.
            if len(timestamps) >= self.config.requests_per_minute:
                oldest = min(timestamps)
                wait_time = max(
                    wait_time,
                    (oldest + timedelta(minutes=1) - now).total_seconds()
                )

            # Check requests per second (burst control)
            one_second_ago = now - timedelta(seconds=1)
            recent = sum(1 for ts in timestamps if ts > one_second_ago)
            if recent >= self.config.requests_per_second:
                wait_time = max(wait_time, 1.0 / self.config.requests_per_second)

            if tokens > 0:
                # Check tokens per minute. The wait is an estimate that
                # assumes the overshoot drains at tokens_per_minute / 60
                # tokens per second.
                projected = self._state.tokens_this_minute + tokens
                if projected > self.config.tokens_per_minute:
                    token_wait = (
                        (projected - self.config.tokens_per_minute)
                        / self.config.tokens_per_minute * 60
                    )
                    wait_time = max(wait_time, token_wait)

                # Check daily token limit
                if self._state.tokens_today + tokens > self.config.tokens_per_day:
                    # Need to wait until next day - return large wait time
                    logger.warning(f"Daily token limit reached for {self.provider}")
                    wait_time = max(wait_time, 3600.0)  # Wait 1 hour as indicator

            # Apply current backoff if any
            if self._state.current_backoff > 0:
                wait_time = max(wait_time, self._state.current_backoff)

            # Track all acquire calls
            self._state.current_requests += 1
            self._state.current_tokens += tokens

            return wait_time

    async def record_request(self, tokens: int = 0) -> None:
        """
        Record a request was made.

        Args:
            tokens: Tokens used in the request
        """
        async with self._lock:
            now = datetime.now()
            self._state.request_timestamps.append(now)
            self._state.tokens_this_minute += tokens
            self._state.tokens_today += tokens
            if tokens:
                # Remember when the tokens were used so they can expire.
                self._token_events.append((now, tokens))

    async def record_success(self) -> None:
        """Record successful request.

        Clears the consecutive-failure count, halves the current backoff
        and releases one in-flight request slot.
        """
        async with self._lock:
            self._state.consecutive_failures = 0
            self._state.success_count += 1
            # Reduce backoff on success
            self._state.current_backoff = max(0.0, self._state.current_backoff * 0.5)
            # Free up a concurrent slot
            if self._state.current_requests > 0:
                self._state.current_requests -= 1

    async def record_failure(self, is_rate_limit: bool = False) -> float:
        """
        Record failed request - apply backoff.

        The backoff is ``initial_backoff_seconds`` multiplied by
        ``backoff_multiplier`` once per previous consecutive failure,
        doubled for rate-limit failures, capped at
        ``max_backoff_seconds`` and then jittered. The result never
        exceeds ``max_backoff_seconds`` and is never negative.

        Args:
            is_rate_limit: Whether failure was due to rate limiting

        Returns:
            Backoff time to wait before retry
        """
        async with self._lock:
            self._state.consecutive_failures += 1
            self._state.error_count += 1
            self._state.last_failure_time = datetime.now()

            # Free up a concurrent slot
            if self._state.current_requests > 0:
                self._state.current_requests -= 1

            max_backoff = self.config.max_backoff_seconds

            # Calculate exponential backoff. After enough consecutive
            # failures the power no longer fits in a float and raises
            # OverflowError; that simply means "at the cap".
            try:
                backoff = self.config.initial_backoff_seconds * (
                    self.config.backoff_multiplier
                    ** (self._state.consecutive_failures - 1)
                )
            except OverflowError:
                backoff = max_backoff

            # Double backoff for rate limit errors
            if is_rate_limit:
                backoff *= 2

            backoff = min(backoff, max_backoff)

            # Add jitter after capping so retries at the cap are still
            # spread out, then clamp so the cap is never exceeded.
            jitter = random.uniform(
                -self.config.jitter_factor * backoff,
                self.config.jitter_factor * backoff
            )
            backoff = min(max(backoff + jitter, 0.0), max_backoff)

            self._state.current_backoff = backoff

            logger.warning(
                f"Rate limiter backoff for {self.provider}: {backoff:.2f}s "
                f"(failures: {self._state.consecutive_failures})"
            )

            return backoff

    def should_retry(self) -> bool:
        """Check if we should retry after failure."""
        return self._state.consecutive_failures < self.config.max_retries

    def get_stats(self) -> Dict[str, Any]:
        """Get current rate limiter statistics.

        Expired request timestamps and token usage are pruned first so
        the per-minute figures reflect the last minute only.
        """
        self._clean_old_requests()
        self._expire_old_tokens()
        return {
            "provider": self.provider,
            "requests_last_minute": len(self._state.request_timestamps),
            "tokens_last_minute": self._state.tokens_this_minute,
            "tokens_today": self._state.tokens_today,
            "consecutive_failures": self._state.consecutive_failures,
            "current_backoff": self._state.current_backoff,
            "limits": {
                "requests_per_minute": self.config.requests_per_minute,
                "tokens_per_minute": self.config.tokens_per_minute,
                "tokens_per_day": self.config.tokens_per_day,
            },
        }


class RateLimiterRegistry:
    """Holds one RateLimiter per provider name, creating them on demand."""

    def __init__(self):
        """Create an empty registry with its own asyncio lock."""
        self._lock = asyncio.Lock()
        self._limiters: Dict[str, RateLimiter] = {}

    def get(self, provider: str) -> RateLimiter:
        """
        Return the limiter registered for a provider (synchronous).

        A limiter is created and stored on first use, configured from
        DEFAULT_RATE_LIMITS with the "default" entry as fallback.

        Args:
            provider: Provider name

        Returns:
            RateLimiter for the provider
        """
        if provider in self._limiters:
            return self._limiters[provider]

        limiter = RateLimiter(
            config=DEFAULT_RATE_LIMITS.get(
                provider, DEFAULT_RATE_LIMITS["default"]
            ),
            provider=provider,
        )
        self._limiters[provider] = limiter
        return limiter

    async def get_limiter(self, provider: str) -> RateLimiter:
        """Async variant of ``get``; runs the lookup under the registry lock."""
        async with self._lock:
            limiter = self.get(provider)
        return limiter

    def set_limiter(self, provider: str, limiter: RateLimiter) -> None:
        """Register (or replace) the limiter used for a provider."""
        self._limiters[provider] = limiter

    def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
        """Collect ``get_stats()`` from every registered limiter, by provider."""
        stats: Dict[str, Dict[str, Any]] = {}
        for name, limiter in self._limiters.items():
            stats[name] = limiter.get_stats()
        return stats


# Process-wide registry, created on first use
_rate_limiter_registry: Optional[RateLimiterRegistry] = None


def get_rate_limiter_registry() -> RateLimiterRegistry:
    """Return the shared registry, creating it on the first call."""
    global _rate_limiter_registry
    registry = _rate_limiter_registry
    if registry is None:
        registry = RateLimiterRegistry()
        _rate_limiter_registry = registry
    return registry


async def with_rate_limit(
    provider: str,
    tokens: int = 0,
) -> float:
    """
    Acquire rate limit permission from the global registry's limiter.

    Looks up (or creates) the limiter for the provider and calls its
    ``acquire``. This function does not sleep; it only reports the wait.

    Args:
        provider: Provider name
        tokens: Expected tokens

    Returns:
        Wait time before proceeding
    """
    limiter = await get_rate_limiter_registry().get_limiter(provider)
    wait_seconds = await limiter.acquire(tokens)
    return wait_seconds
