"""
Agent Budget - Per-agent cost and usage limits.

Implements:
- Daily token limits per agent
- Daily cost limits per agent
- Budget enforcement before task routing
- Usage tracking with daily reset
"""

from dataclasses import dataclass, field
from datetime import datetime, date
from enum import Enum
from typing import Optional, Dict, Any
import logging

logger = logging.getLogger(__name__)


class BudgetStatus(Enum):
    """Outcome category of a budget check (carried in BudgetCheckResult.status)."""
    WITHIN_BUDGET = "within_budget"  # Allowed; usage is below the agent's warning threshold
    WARNING = "warning"  # Allowed, but usage has reached warning_threshold (default 0.8, i.e. >= 80%)
    EXCEEDED = "exceeded"  # Rejected: a per-request cap or a daily limit would be exceeded
    PAUSED = "paused"  # Rejected: the agent's budget was manually paused


@dataclass
class AgentBudget:
    """Budget configuration for a single agent.

    Holds daily token/cost caps, per-request caps, the warning threshold, and a
    mutable pause flag. Instances are mutable; code that shares one instance
    between owners shares its pause state too.
    """

    agent_id: str

    # Token limits (per calendar day)
    daily_input_tokens: int = 1_000_000  # 1M input tokens/day
    daily_output_tokens: int = 200_000   # 200K output tokens/day

    # Cost limits (in USD)
    daily_cost_limit: float = 50.0  # $50/day default
    # NOTE(review): monthly_cost_limit is not checked anywhere in this module —
    # confirm whether it is enforced elsewhere or is informational only.
    monthly_cost_limit: float = 1000.0  # $1000/month default

    # Per-request limits
    max_tokens_per_request: int = 100_000  # Cap on estimated input + output tokens of one request
    max_cost_per_request: float = 5.0  # Cap on estimated USD cost of one request

    # Warning threshold: fraction (0-1) of any daily limit at which checks report WARNING
    warning_threshold: float = 0.8  # Warn at 80%

    # State: when paused, every budget check is rejected with pause_reason
    is_paused: bool = False
    pause_reason: Optional[str] = None

    # Metadata (naive local timestamps from datetime.now)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)


@dataclass
class DailyUsage:
    """Accumulated usage for one agent on one calendar day."""

    agent_id: str
    date: date  # The calendar day this record covers

    # Token usage (running totals for the day)
    input_tokens: int = 0
    output_tokens: int = 0

    # Cost tracking (USD, running total for the day)
    total_cost: float = 0.0

    # Request counts
    request_count: int = 0  # Recorded requests (successful + failed)
    successful_requests: int = 0
    failed_requests: int = 0
    rejected_requests: int = 0  # Rejected due to budget

    # Peak tracking (largest single request seen today).
    # These and the timestamps below are not written by the enforcer's DB
    # save helper, so they are in-memory only.
    peak_tokens_request: int = 0
    peak_cost_request: float = 0.0

    # Timestamps
    first_request_at: Optional[datetime] = None
    last_request_at: Optional[datetime] = None


@dataclass
class BudgetCheckResult:
    """Result of a budget check.

    Usage, remaining, and percentage fields describe the agent's usage before
    the checked request. Fields left unset default to zero.
    """

    status: BudgetStatus
    allowed: bool  # True only for WITHIN_BUDGET / WARNING outcomes

    # Current usage (today)
    input_tokens_used: int = 0
    output_tokens_used: int = 0
    cost_used: float = 0.0

    # Remaining budget (limit - used; negative once a limit has been overrun)
    input_tokens_remaining: int = 0
    output_tokens_remaining: int = 0
    cost_remaining: float = 0.0

    # Percentages, expressed as fractions (0.5 == 50%) of the daily limit
    input_token_percentage: float = 0.0
    output_token_percentage: float = 0.0
    cost_percentage: float = 0.0

    # Reason if not allowed
    reason: Optional[str] = None

    # Estimated request cost in USD (for pre-check)
    estimated_cost: float = 0.0


# Default budgets for known agent types.
# These AgentBudget instances are mutable (e.g. is_paused); anything that
# intends to mutate a budget should copy the instances, not just this dict,
# so the module-level defaults stay pristine.
DEFAULT_BUDGETS: Dict[str, AgentBudget] = {
    "claude-code": AgentBudget(
        agent_id="claude-code",
        daily_input_tokens=2_000_000,  # Claude Code does a lot of file reading
        daily_output_tokens=500_000,
        daily_cost_limit=100.0,
        monthly_cost_limit=2000.0,
    ),
    "gemini-cli": AgentBudget(
        agent_id="gemini-cli",
        daily_input_tokens=5_000_000,  # Gemini has 1M context, often used for large docs
        daily_output_tokens=500_000,
        daily_cost_limit=50.0,
        monthly_cost_limit=1000.0,
    ),
    "codex-cli": AgentBudget(
        agent_id="codex-cli",
        daily_input_tokens=1_000_000,
        daily_output_tokens=200_000,
        daily_cost_limit=30.0,
        monthly_cost_limit=600.0,
    ),
    "claude-sdk": AgentBudget(
        agent_id="claude-sdk",
        daily_input_tokens=2_000_000,
        daily_output_tokens=500_000,
        daily_cost_limit=100.0,
        monthly_cost_limit=2000.0,
    ),
    "openai-agents": AgentBudget(
        agent_id="openai-agents",
        daily_input_tokens=1_500_000,
        daily_output_tokens=300_000,
        daily_cost_limit=75.0,
        monthly_cost_limit=1500.0,
    ),
}

# Approximate costs in USD per 1K tokens (for estimation).
# Agents without their own entry are priced with the "default" rates.
TOKEN_COSTS = {
    "claude-code": {"input": 0.003, "output": 0.015},  # Claude 3.5 Sonnet pricing
    "claude-sdk": {"input": 0.003, "output": 0.015},
    # Labelled Gemini 1.5 Pro. NOTE(review): $0.35/$1.05 per 1M tokens looks like
    # Gemini 1.5 Flash list pricing rather than 1.5 Pro — confirm the rates.
    "gemini-cli": {"input": 0.00035, "output": 0.00105},
    "codex-cli": {"input": 0.002, "output": 0.008},  # Approximate
    "openai-agents": {"input": 0.005, "output": 0.015},  # GPT-4o pricing
    "default": {"input": 0.003, "output": 0.015},
}


class BudgetEnforcer:
    """Enforces per-agent daily token/cost limits and per-request limits.

    An agent without a configured budget gets a default ``AgentBudget`` on
    first lookup. Usage is accumulated in memory per agent per calendar day
    (``date.today()``, local time). When the ``db`` object exposes
    ``get_daily_usage`` / ``upsert_daily_usage``, usage is also loaded from and
    written back to it.

    Only daily and per-request limits are enforced. ``monthly_cost_limit`` is
    stored on ``AgentBudget`` but is not checked here.
    """

    def __init__(
        self,
        db: Any,
        budgets: Optional[Dict[str, AgentBudget]] = None,
    ) -> None:
        """
        Initialize budget enforcer.

        Args:
            db: Database instance for usage tracking. Its optional methods
                ``get_daily_usage`` and ``upsert_daily_usage`` are used when
                present; otherwise usage lives in memory only.
            budgets: Custom budget configurations. A non-empty dict is stored
                as-is (not copied). When None or empty, a private per-instance
                copy of ``DEFAULT_BUDGETS`` is used.
        """
        from dataclasses import replace

        self.db = db
        if budgets:
            self._budgets = budgets
        else:
            # Copy each AgentBudget, not just the dict. pause_agent() and
            # resume_agent() mutate budgets in place, so a shallow dict copy
            # would leak pause state into the module-level DEFAULT_BUDGETS and
            # into every other enforcer built from the defaults.
            self._budgets = {
                agent_id: replace(budget)
                for agent_id, budget in DEFAULT_BUDGETS.items()
            }
        # Keyed by "<agent_id>:<ISO date>"; only today's entries are retained.
        self._daily_usage: Dict[str, DailyUsage] = {}

    def get_budget(self, agent_id: str) -> AgentBudget:
        """Return the agent's budget, registering a default AgentBudget for unknown agents."""
        if agent_id not in self._budgets:
            # Create default budget for unknown agents
            self._budgets[agent_id] = AgentBudget(agent_id=agent_id)
        return self._budgets[agent_id]

    def set_budget(self, agent_id: str, budget: AgentBudget) -> None:
        """Set or replace the budget for an agent and bump its ``updated_at``."""
        self._budgets[agent_id] = budget
        budget.updated_at = datetime.now()

    def _get_daily_usage(self, agent_id: str) -> DailyUsage:
        """Return today's usage record for *agent_id*, loading or creating it on first use."""
        today = date.today()
        key = f"{agent_id}:{today.isoformat()}"

        usage = self._daily_usage.get(key)
        if usage is None:
            # Lookups only ever use today's key, so records from earlier days
            # are unreachable. Drop them so the cache does not grow forever.
            self._daily_usage = {
                k: v for k, v in self._daily_usage.items() if v.date == today
            }
            usage = self._load_usage_from_db(agent_id, today) or DailyUsage(
                agent_id=agent_id,
                date=today,
            )
            self._daily_usage[key] = usage

        return usage

    def _load_usage_from_db(self, agent_id: str, day: date) -> Optional[DailyUsage]:
        """Load a usage record from the database, or return None if unavailable.

        Only counters and cost are restored; peak and timestamp fields start
        fresh because they are not persisted.
        """
        if not hasattr(self.db, 'get_daily_usage'):
            return None

        record = self.db.get_daily_usage(agent_id, day.isoformat())
        if not record:
            return None

        return DailyUsage(
            agent_id=agent_id,
            date=day,
            input_tokens=record.get('input_tokens', 0),
            output_tokens=record.get('output_tokens', 0),
            total_cost=record.get('total_cost', 0.0),
            request_count=record.get('request_count', 0),
            successful_requests=record.get('successful_requests', 0),
            failed_requests=record.get('failed_requests', 0),
            rejected_requests=record.get('rejected_requests', 0),
        )

    def _save_usage_to_db(self, usage: DailyUsage) -> None:
        """Persist counters and cost for *usage*; no-op if the db lacks ``upsert_daily_usage``."""
        if not hasattr(self.db, 'upsert_daily_usage'):
            return

        self.db.upsert_daily_usage(
            agent_id=usage.agent_id,
            date=usage.date.isoformat(),
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
            total_cost=usage.total_cost,
            request_count=usage.request_count,
            successful_requests=usage.successful_requests,
            failed_requests=usage.failed_requests,
            rejected_requests=usage.rejected_requests,
        )

    @staticmethod
    def _estimate_cost(agent_id: str, input_tokens: int, output_tokens: int) -> float:
        """Price the given token counts with the agent's TOKEN_COSTS rates (USD per 1K tokens)."""
        rates = TOKEN_COSTS.get(agent_id, TOKEN_COSTS["default"])
        return (
            (input_tokens / 1000) * rates["input"] +
            (output_tokens / 1000) * rates["output"]
        )

    @staticmethod
    def _ratio(used: float, limit: float) -> float:
        """Return used/limit as a fraction, or 0 when the limit is not positive."""
        return used / limit if limit > 0 else 0

    def check_budget(
        self,
        agent_id: str,
        estimated_input_tokens: int = 0,
        estimated_output_tokens: int = 0,
    ) -> BudgetCheckResult:
        """
        Check if agent has budget for a request.

        The checks run in this order, and the first one that fails decides
        the result:

        1. Pause flag.
        2. Per-request token cap.
        3. Per-request cost cap.
        4. Daily input, output, and cost limits, using current usage plus this
           request's estimate. All daily limits that would be exceeded are
           reported together.

        Args:
            agent_id: Agent to check
            estimated_input_tokens: Estimated input tokens for request
            estimated_output_tokens: Estimated output tokens for request

        Returns:
            BudgetCheckResult with status and details. The usage, remaining,
            and percentage fields describe usage before this request.
        """
        budget = self.get_budget(agent_id)
        usage = self._get_daily_usage(agent_id)
        estimated_cost = self._estimate_cost(
            agent_id, estimated_input_tokens, estimated_output_tokens
        )

        input_pct = self._ratio(usage.input_tokens, budget.daily_input_tokens)
        output_pct = self._ratio(usage.output_tokens, budget.daily_output_tokens)
        cost_pct = self._ratio(usage.total_cost, budget.daily_cost_limit)

        def _result(
            status: BudgetStatus, allowed: bool, reason: Optional[str] = None
        ) -> BudgetCheckResult:
            # Every outcome reports the same usage snapshot, so callers can
            # rely on all fields being populated whichever check decided.
            return BudgetCheckResult(
                status=status,
                allowed=allowed,
                reason=reason,
                input_tokens_used=usage.input_tokens,
                output_tokens_used=usage.output_tokens,
                cost_used=usage.total_cost,
                input_tokens_remaining=budget.daily_input_tokens - usage.input_tokens,
                output_tokens_remaining=budget.daily_output_tokens - usage.output_tokens,
                cost_remaining=budget.daily_cost_limit - usage.total_cost,
                input_token_percentage=input_pct,
                output_token_percentage=output_pct,
                cost_percentage=cost_pct,
                estimated_cost=estimated_cost,
            )

        if budget.is_paused:
            return _result(
                BudgetStatus.PAUSED,
                False,
                budget.pause_reason or "Agent budget is paused",
            )

        # Per-request caps apply regardless of how much daily budget remains.
        if estimated_input_tokens + estimated_output_tokens > budget.max_tokens_per_request:
            return _result(
                BudgetStatus.EXCEEDED,
                False,
                f"Request exceeds max tokens per request ({budget.max_tokens_per_request})",
            )

        if estimated_cost > budget.max_cost_per_request:
            return _result(
                BudgetStatus.EXCEEDED,
                False,
                f"Request exceeds max cost per request (${budget.max_cost_per_request:.2f})",
            )

        # Daily limits: usage so far plus this request's estimate.
        exceeded_reasons = []
        if usage.input_tokens + estimated_input_tokens > budget.daily_input_tokens:
            exceeded_reasons.append("daily input token limit")
        if usage.output_tokens + estimated_output_tokens > budget.daily_output_tokens:
            exceeded_reasons.append("daily output token limit")
        if usage.total_cost + estimated_cost > budget.daily_cost_limit:
            exceeded_reasons.append("daily cost limit")

        if exceeded_reasons:
            return _result(
                BudgetStatus.EXCEEDED,
                False,
                f"Would exceed {', '.join(exceeded_reasons)}",
            )

        # The warning is based on usage so far, not usage including this request.
        max_pct = max(input_pct, output_pct, cost_pct)
        status = (
            BudgetStatus.WARNING
            if max_pct >= budget.warning_threshold
            else BudgetStatus.WITHIN_BUDGET
        )
        return _result(status, True)

    def record_usage(
        self,
        agent_id: str,
        input_tokens: int,
        output_tokens: int,
        cost: Optional[float] = None,
        success: bool = True,
    ) -> None:
        """
        Record usage for an agent request and persist today's totals.

        Logs a warning once any daily limit reaches the agent's warning threshold.

        Args:
            agent_id: Agent that made the request
            input_tokens: Input tokens used
            output_tokens: Output tokens used
            cost: Actual cost in USD (estimated from TOKEN_COSTS if not provided)
            success: Whether request succeeded
        """
        usage = self._get_daily_usage(agent_id)
        now = datetime.now()

        if cost is None:
            cost = self._estimate_cost(agent_id, input_tokens, output_tokens)

        # Update running totals
        usage.input_tokens += input_tokens
        usage.output_tokens += output_tokens
        usage.total_cost += cost

        # Update request counts
        usage.request_count += 1
        if success:
            usage.successful_requests += 1
        else:
            usage.failed_requests += 1

        # Track peak single-request usage
        usage.peak_tokens_request = max(usage.peak_tokens_request, input_tokens + output_tokens)
        usage.peak_cost_request = max(usage.peak_cost_request, cost)

        # Update timestamps
        if usage.first_request_at is None:
            usage.first_request_at = now
        usage.last_request_at = now

        # Persist to database
        self._save_usage_to_db(usage)

        # Warn if approaching limits. Use the zero-safe ratio so an agent
        # configured with a 0 limit does not raise ZeroDivisionError here.
        budget = self.get_budget(agent_id)
        input_pct = self._ratio(usage.input_tokens, budget.daily_input_tokens)
        output_pct = self._ratio(usage.output_tokens, budget.daily_output_tokens)
        cost_pct = self._ratio(usage.total_cost, budget.daily_cost_limit)

        if max(input_pct, output_pct, cost_pct) >= budget.warning_threshold:
            logger.warning(
                "Agent %s approaching budget limit: input=%.1f%%, output=%.1f%%, cost=%.1f%%",
                agent_id,
                input_pct * 100,
                output_pct * 100,
                cost_pct * 100,
            )

    def record_rejection(self, agent_id: str) -> None:
        """Count a budget rejection against today's usage and persist it."""
        usage = self._get_daily_usage(agent_id)
        usage.rejected_requests += 1
        self._save_usage_to_db(usage)

    def pause_agent(self, agent_id: str, reason: str) -> None:
        """Pause an agent's budget; every check is rejected until resume_agent()."""
        budget = self.get_budget(agent_id)
        budget.is_paused = True
        budget.pause_reason = reason
        budget.updated_at = datetime.now()
        logger.info("Paused agent %s budget: %s", agent_id, reason)

    def resume_agent(self, agent_id: str) -> None:
        """Clear an agent's pause flag and reason."""
        budget = self.get_budget(agent_id)
        budget.is_paused = False
        budget.pause_reason = None
        budget.updated_at = datetime.now()
        logger.info("Resumed agent %s budget", agent_id)

    def get_usage_summary(self, agent_id: str) -> Dict[str, Any]:
        """Return a plain-dict snapshot of today's budget, usage, remaining, and percentages."""
        budget = self.get_budget(agent_id)
        usage = self._get_daily_usage(agent_id)

        return {
            "agent_id": agent_id,
            "date": usage.date.isoformat(),
            "budget": {
                "daily_input_tokens": budget.daily_input_tokens,
                "daily_output_tokens": budget.daily_output_tokens,
                "daily_cost_limit": budget.daily_cost_limit,
            },
            "usage": {
                "input_tokens": usage.input_tokens,
                "output_tokens": usage.output_tokens,
                "total_cost": usage.total_cost,
                "request_count": usage.request_count,
            },
            "remaining": {
                "input_tokens": budget.daily_input_tokens - usage.input_tokens,
                "output_tokens": budget.daily_output_tokens - usage.output_tokens,
                "cost": budget.daily_cost_limit - usage.total_cost,
            },
            "percentages": {
                "input_tokens": self._ratio(usage.input_tokens, budget.daily_input_tokens),
                "output_tokens": self._ratio(usage.output_tokens, budget.daily_output_tokens),
                "cost": self._ratio(usage.total_cost, budget.daily_cost_limit),
            },
            "status": "paused" if budget.is_paused else "active",
        }

    def reset_daily_usage(self, agent_id: str) -> None:
        """Manually reset today's in-memory usage for an agent (for testing).

        The database copy is not modified.
        """
        today = date.today()
        key = f"{agent_id}:{today.isoformat()}"
        self._daily_usage[key] = DailyUsage(
            agent_id=agent_id,
            date=today,
        )
        logger.info("Reset daily usage for agent %s", agent_id)


# Singleton instance: the process-wide enforcer consulted by has_budget().
# None means no enforcer is registered and has_budget() allows everything.
_budget_enforcer: Optional[BudgetEnforcer] = None


def get_budget_enforcer() -> Optional[BudgetEnforcer]:
    """
    Return the globally registered BudgetEnforcer.

    Returns:
        The instance last passed to set_budget_enforcer(), or None if none
        has been registered yet.
    """
    current = _budget_enforcer
    return current


def set_budget_enforcer(enforcer: Optional[BudgetEnforcer]) -> None:
    """
    Register the global budget enforcer instance.

    Args:
        enforcer: The enforcer that get_budget_enforcer() and has_budget()
            will use. Pass None to unregister it, which disables enforcement
            (has_budget() then always returns True).
    """
    global _budget_enforcer
    _budget_enforcer = enforcer


def has_budget(
    agent_id: str,
    estimated_tokens: int = 0,
    *,
    output_ratio: float = 0.2,
) -> bool:
    """
    Quick check if agent has budget available.

    Delegates to the global enforcer's ``check_budget``. If no enforcer is
    registered, enforcement is off and this always returns True.

    Args:
        agent_id: Agent to check.
        estimated_tokens: Estimated input tokens for the request.
        output_ratio: Expected output tokens as a fraction of the input. The
            default of 0.2 keeps the historical "output is ~20% of input"
            estimate.

    Returns:
        Whether the estimated request is allowed.
    """
    enforcer = get_budget_enforcer()
    if enforcer is None:
        return True  # No enforcement configured

    result = enforcer.check_budget(
        agent_id,
        estimated_input_tokens=estimated_tokens,
        estimated_output_tokens=int(estimated_tokens * output_ratio),
    )
    return result.allowed
