"""
Token Audit Integration - Local usage ledger with MCP query support.

Token Audit is a local ledger that tracks LLM API usage and costs.
This module provides:
- Integration with Token Audit MCP server
- Usage querying and aggregation
- Budget enforcement integration
- Cost tracking and reporting
"""

from dataclasses import dataclass, field
from datetime import datetime, date, timedelta
from typing import Optional, Dict, List, Any
import json
import logging

logger = logging.getLogger(__name__)


@dataclass
class UsageEntry:
    """One LLM API call as recorded in the Token Audit ledger.

    The first six fields are required; the remaining identifiers and tags
    are optional and default to "not set".
    """

    # Required: when the call happened, what served it, and what it cost.
    timestamp: datetime
    provider: str  # e.g. anthropic, openai, google
    model: str
    input_tokens: int
    output_tokens: int
    cost_usd: float

    # Optional identifiers used to attribute the call.
    agent_id: Optional[str] = None
    task_id: Optional[str] = None
    session_id: Optional[str] = None
    request_id: Optional[str] = None

    # Free-form key/value metadata; every instance gets its own dict.
    tags: Dict[str, str] = field(default_factory=dict)


@dataclass
class UsageSummary:
    """Usage totals and breakdowns for the window period_start..period_end.

    Only the two period bounds are required; all counters start at zero and
    each breakdown starts as its own empty dict.
    """

    # Bounds of the summarized window.
    period_start: datetime
    period_end: datetime

    # Grand totals across every entry in the window.
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost_usd: float = 0.0
    total_requests: int = 0

    # Breakdowns keyed by provider name, model name and agent id; each value
    # is a dict of per-key figures.
    by_provider: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    by_model: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    by_agent: Dict[str, Dict[str, Any]] = field(default_factory=dict)


@dataclass
class CostAlert:
    """Result of comparing a measured cost against a configured threshold."""

    threshold_type: str  # one of: daily, hourly, per_request
    threshold_value: float  # the configured limit
    current_value: float  # the cost that was compared against the limit
    exceeded: bool
    message: str  # human-readable description of the alert

    # Defaults to the moment the alert object is created.
    timestamp: datetime = field(default_factory=datetime.now)


class TokenAuditClient:
    """
    Client for Token Audit integration.

    Token Audit provides a local ledger for tracking LLM API usage.
    This client can:
    - Query usage via MCP, falling back to the local database
    - Aggregate usage data
    - Check cost thresholds
    - Generate reports

    Daily summaries are cached for ``cache_ttl_seconds`` (default 300). The
    cache for a day is also dropped whenever this client successfully
    records usage for that day, so threshold checks do not keep reading a
    stale total.
    """

    # Entry cap applied to usage queries when the caller passes no limit.
    # Aggregations warn when a query comes back this full, because their
    # totals may then be missing entries.
    _DEFAULT_QUERY_LIMIT = 1000

    def __init__(
        self,
        mcp_client: Optional[Any] = None,
        db: Optional[Any] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize Token Audit client.

        Args:
            mcp_client: MCP client for querying Token Audit server; must
                expose an awaitable ``call_tool(server=, tool=, arguments=)``
            db: Database for caching/persistence; its ``query_usage`` and
                ``record_usage`` methods are used when present
            config: Configuration options. Recognized keys:
                ``daily_cost_threshold`` (default 100.0),
                ``hourly_cost_threshold`` (default 20.0),
                ``request_cost_threshold`` (default 5.0),
                ``cache_ttl_seconds`` (default 300)
        """
        self.mcp_client = mcp_client
        self.db = db
        self.config = config or {}

        # Default cost thresholds
        self._daily_threshold = self.config.get("daily_cost_threshold", 100.0)
        self._hourly_threshold = self.config.get("hourly_cost_threshold", 20.0)
        self._request_threshold = self.config.get("request_cost_threshold", 5.0)

        # Cache for usage data. _cache_timestamps records when each entry of
        # _usage_cache was stored (same keys) so the TTL can be enforced.
        self._usage_cache: Dict[str, UsageSummary] = {}
        self._cache_timestamps: Dict[str, datetime] = {}
        self._cache_ttl = self.config.get("cache_ttl_seconds", 300)  # 5 min

        # Alert callbacks
        self._alert_callbacks: List[callable] = []

    # ------------------------------------------------------------------
    # Querying
    # ------------------------------------------------------------------

    async def query_usage(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        agent_id: Optional[str] = None,
        limit: int = _DEFAULT_QUERY_LIMIT,
    ) -> List[UsageEntry]:
        """
        Query usage entries from Token Audit.

        The MCP server is asked first when an MCP client is configured. If
        that call (or parsing its response) fails, the error is logged and
        the local database is queried instead. Errors raised by the
        database query itself are not caught.

        Args:
            start_time: Start of time range
            end_time: End of time range
            provider: Filter by provider
            model: Filter by model
            agent_id: Filter by agent
            limit: Maximum entries to return

        Returns:
            List of UsageEntry objects
        """
        if self.mcp_client is None:
            return self._query_from_db(start_time, end_time, provider, model, agent_id, limit)

        try:
            # Query via MCP
            params = {
                "start_time": start_time.isoformat() if start_time else None,
                "end_time": end_time.isoformat() if end_time else None,
                "provider": provider,
                "model": model,
                "agent_id": agent_id,
                "limit": limit,
            }

            # Remove None values
            params = {k: v for k, v in params.items() if v is not None}

            result = await self.mcp_client.call_tool(
                server="token-audit",
                tool="query_usage",
                arguments=params,
            )

            return self._parse_usage_response(result)

        except Exception as e:
            # Deliberately broad: any MCP failure degrades to the local ledger.
            logger.error("Failed to query Token Audit via MCP: %s", e)
            # Fall back to database
            return self._query_from_db(start_time, end_time, provider, model, agent_id, limit)

    async def _query_window(
        self,
        start_time: datetime,
        end_time: datetime,
        **filters: Any,
    ) -> List[UsageEntry]:
        """Fetch the entries of a time window for aggregation.

        Uses the default query limit and logs a warning when the result is
        that large, since the aggregate built from it may be incomplete.
        """
        entries = await self.query_usage(
            start_time=start_time,
            end_time=end_time,
            **filters,
        )

        if len(entries) >= self._DEFAULT_QUERY_LIMIT:
            logger.warning(
                "Usage query for %s - %s returned %d entries (the query limit); "
                "aggregated totals may be incomplete",
                start_time.isoformat(),
                end_time.isoformat(),
                len(entries),
            )

        return entries

    @staticmethod
    def _coerce_timestamp(value: Any) -> datetime:
        """Return ``value`` as a datetime, parsing ISO-format strings."""
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    @staticmethod
    def _coerce_tags(value: Any) -> Dict[str, str]:
        """Return ``value`` as a tags dict.

        Accepts a dict, a JSON-encoded dict (the form ``record_usage``
        writes to the database), or an empty/missing value. Anything else
        yields an empty dict so bad metadata never discards a usage entry.
        """
        if not value:
            return {}

        if isinstance(value, str):
            try:
                value = json.loads(value)
            except ValueError:
                logger.warning("Ignoring undecodable usage tags: %r", value)
                return {}

        return dict(value) if isinstance(value, dict) else {}

    def _query_from_db(
        self,
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        provider: Optional[str],
        model: Optional[str],
        agent_id: Optional[str],
        limit: int,
    ) -> List[UsageEntry]:
        """Query usage from local database.

        Returns an empty list when no database is configured or it has no
        ``query_usage`` method. Each record must provide ``timestamp``,
        ``provider``, ``model``, ``input_tokens``, ``output_tokens`` and
        ``cost_usd``; the identifier fields and ``tags`` are optional.
        """
        if self.db is None or not hasattr(self.db, 'query_usage'):
            return []

        records = self.db.query_usage(
            start_time=start_time,
            end_time=end_time,
            provider=provider,
            model=model,
            agent_id=agent_id,
            limit=limit,
        )

        return [
            UsageEntry(
                timestamp=self._coerce_timestamp(r['timestamp']),
                provider=r['provider'],
                model=r['model'],
                input_tokens=r['input_tokens'],
                output_tokens=r['output_tokens'],
                cost_usd=r['cost_usd'],
                agent_id=r.get('agent_id'),
                task_id=r.get('task_id'),
                session_id=r.get('session_id'),
                request_id=r.get('request_id'),
                tags=self._coerce_tags(r.get('tags')),
            )
            for r in (records or [])
        ]

    def _parse_usage_response(self, response: Any) -> List[UsageEntry]:
        """Parse MCP response into UsageEntry objects.

        Accepts either a list of entry dicts or a dict holding that list
        under ``entries``; any other response yields an empty list. Entries
        that cannot be parsed are logged and skipped so one malformed item
        does not discard the rest of the response.
        """
        entries: List[UsageEntry] = []

        if isinstance(response, dict):
            data = response.get('entries') or []
        elif isinstance(response, list):
            data = response
        else:
            return entries

        for item in data:
            try:
                entries.append(UsageEntry(
                    timestamp=self._coerce_timestamp(item['timestamp']),
                    # "or" also replaces explicit nulls, which would
                    # otherwise become None keys/values during aggregation.
                    provider=item.get('provider') or 'unknown',
                    model=item.get('model') or 'unknown',
                    input_tokens=int(item.get('input_tokens') or 0),
                    output_tokens=int(item.get('output_tokens') or 0),
                    cost_usd=float(item.get('cost_usd') or 0.0),
                    agent_id=item.get('agent_id'),
                    task_id=item.get('task_id'),
                    session_id=item.get('session_id'),
                    request_id=item.get('request_id'),
                    tags=self._coerce_tags(item.get('tags')),
                ))
            except (KeyError, ValueError, TypeError, AttributeError) as e:
                # TypeError/AttributeError cover non-dict items and
                # non-string timestamps.
                logger.warning("Failed to parse usage entry: %s", e)
                continue

        return entries

    # ------------------------------------------------------------------
    # Summaries
    # ------------------------------------------------------------------

    def _cache_get(self, cache_key: str) -> Optional[UsageSummary]:
        """Return the cached summary for ``cache_key`` if still within the TTL.

        Expired entries, and entries without a recorded store time, are
        evicted and reported as a miss.
        """
        summary = self._usage_cache.get(cache_key)
        if summary is None:
            return None

        stored_at = self._cache_timestamps.get(cache_key)
        if stored_at is not None:
            age_seconds = (datetime.now() - stored_at).total_seconds()
            # A negative age means the wall clock moved backwards; treat the
            # entry as stale rather than trusting it indefinitely.
            if 0 <= age_seconds < self._cache_ttl:
                return summary

        self._usage_cache.pop(cache_key, None)
        self._cache_timestamps.pop(cache_key, None)
        return None

    def _invalidate_day(self, day: date) -> None:
        """Drop every cached daily summary (all agents) for ``day``."""
        prefix = f"daily:{day.isoformat()}:"
        stale_keys = [key for key in self._usage_cache if key.startswith(prefix)]
        for key in stale_keys:
            self._usage_cache.pop(key, None)
            self._cache_timestamps.pop(key, None)

    async def get_daily_summary(
        self,
        day: Optional[date] = None,
        agent_id: Optional[str] = None,
    ) -> UsageSummary:
        """
        Get usage summary for a day.

        Results are cached per (day, agent) for ``cache_ttl_seconds``.
        The summary is built from one query using the default entry limit.

        Args:
            day: Date to summarize (default: today)
            agent_id: Filter by agent

        Returns:
            UsageSummary for the day
        """
        if day is None:
            day = date.today()

        # Check cache
        cache_key = f"daily:{day.isoformat()}:{agent_id or 'all'}"
        cached = self._cache_get(cache_key)
        if cached is not None:
            return cached

        start_time = datetime.combine(day, datetime.min.time())
        end_time = datetime.combine(day, datetime.max.time())

        entries = await self._query_window(start_time, end_time, agent_id=agent_id)
        summary = self._aggregate_entries(entries, start_time, end_time)

        # Cache the result together with its store time for the TTL check
        self._usage_cache[cache_key] = summary
        self._cache_timestamps[cache_key] = datetime.now()

        return summary

    async def get_hourly_summary(
        self,
        hour: Optional[datetime] = None,
        agent_id: Optional[str] = None,
    ) -> UsageSummary:
        """
        Get usage summary for an hour.

        The window runs from ``hour`` exactly as given to one hour later;
        a supplied value is not rounded down to the start of the hour.
        Hourly summaries are not cached.

        Args:
            hour: Hour to summarize (default: current hour)
            agent_id: Filter by agent

        Returns:
            UsageSummary for the hour
        """
        if hour is None:
            hour = datetime.now().replace(minute=0, second=0, microsecond=0)

        start_time = hour
        # NOTE(review): whether the backend treats end_time as inclusive is
        # not visible here - confirm entries on the boundary are not counted
        # in two consecutive hours.
        end_time = hour + timedelta(hours=1)

        entries = await self._query_window(start_time, end_time, agent_id=agent_id)

        return self._aggregate_entries(entries, start_time, end_time)

    @staticmethod
    def _add_to_bucket(
        buckets: Dict[str, Dict[str, Any]],
        key: str,
        entry: UsageEntry,
    ) -> None:
        """Add one entry's tokens, cost and request count to ``buckets[key]``."""
        bucket = buckets.get(key)
        if bucket is None:
            bucket = buckets[key] = {
                "input_tokens": 0,
                "output_tokens": 0,
                "cost_usd": 0.0,
                "requests": 0,
            }

        bucket["input_tokens"] += entry.input_tokens
        bucket["output_tokens"] += entry.output_tokens
        bucket["cost_usd"] += entry.cost_usd
        bucket["requests"] += 1

    def _aggregate_entries(
        self,
        entries: List[UsageEntry],
        start_time: datetime,
        end_time: datetime,
    ) -> UsageSummary:
        """Aggregate usage entries into a summary.

        Every entry contributes to the totals and to the per-provider and
        per-model breakdowns; only entries with a truthy ``agent_id``
        appear in the per-agent breakdown.
        """
        summary = UsageSummary(
            period_start=start_time,
            period_end=end_time,
        )

        for entry in entries:
            # Update totals
            summary.total_input_tokens += entry.input_tokens
            summary.total_output_tokens += entry.output_tokens
            summary.total_cost_usd += entry.cost_usd
            summary.total_requests += 1

            self._add_to_bucket(summary.by_provider, entry.provider, entry)
            self._add_to_bucket(summary.by_model, entry.model, entry)
            if entry.agent_id:
                self._add_to_bucket(summary.by_agent, entry.agent_id, entry)

        return summary

    # ------------------------------------------------------------------
    # Recording
    # ------------------------------------------------------------------

    async def record_usage(
        self,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cost_usd: float,
        agent_id: Optional[str] = None,
        task_id: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> bool:
        """
        Record usage entry to Token Audit.

        The entry is timestamped with the current local time. MCP is tried
        first; if it is unavailable or fails, the database is used. Failures
        are logged, not raised. On success the cached daily summaries for
        the entry's day are dropped so the next summary reflects it.

        Args:
            provider: LLM provider name
            model: Model name
            input_tokens: Input token count
            output_tokens: Output token count
            cost_usd: Cost in USD
            agent_id: Agent that made the request
            task_id: Associated task
            tags: Additional tags

        Returns:
            True if recorded successfully
        """
        entry = UsageEntry(
            timestamp=datetime.now(),
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_usd=cost_usd,
            agent_id=agent_id,
            task_id=task_id,
            tags=tags or {},
        )

        recorded = False

        # Try MCP first
        if self.mcp_client is not None:
            try:
                await self.mcp_client.call_tool(
                    server="token-audit",
                    tool="record_usage",
                    arguments={
                        "timestamp": entry.timestamp.isoformat(),
                        "provider": entry.provider,
                        "model": entry.model,
                        "input_tokens": entry.input_tokens,
                        "output_tokens": entry.output_tokens,
                        "cost_usd": entry.cost_usd,
                        "agent_id": entry.agent_id,
                        "task_id": entry.task_id,
                        "tags": entry.tags,
                    },
                )
                recorded = True
            except Exception as e:
                logger.error("Failed to record usage via MCP: %s", e)

        # Fall back to database
        if not recorded and self.db is not None and hasattr(self.db, 'record_usage'):
            try:
                self.db.record_usage(
                    timestamp=entry.timestamp.isoformat(),
                    provider=entry.provider,
                    model=entry.model,
                    input_tokens=entry.input_tokens,
                    output_tokens=entry.output_tokens,
                    cost_usd=entry.cost_usd,
                    agent_id=entry.agent_id,
                    task_id=entry.task_id,
                    tags=json.dumps(entry.tags),
                )
                recorded = True
            except Exception as e:
                logger.error("Failed to record usage to database: %s", e)

        if recorded:
            # The cached total for this day is now out of date.
            self._invalidate_day(entry.timestamp.date())

        return recorded

    # ------------------------------------------------------------------
    # Thresholds and alerts
    # ------------------------------------------------------------------

    def _alert_if_exceeded(
        self,
        threshold_type: str,
        label: str,
        current_value: float,
        threshold_value: float,
    ) -> Optional[CostAlert]:
        """Build and fire an alert when ``current_value`` reaches the threshold.

        Returns the alert, or None when the value is below the threshold.
        """
        if not current_value >= threshold_value:
            return None

        alert = CostAlert(
            threshold_type=threshold_type,
            threshold_value=threshold_value,
            current_value=current_value,
            exceeded=True,
            message=f"{label} cost threshold exceeded: ${current_value:.2f} >= ${threshold_value:.2f}",
        )
        self._fire_alert(alert)
        return alert

    async def check_thresholds(
        self,
        agent_id: Optional[str] = None,
    ) -> List[CostAlert]:
        """
        Check if any cost thresholds are exceeded.

        Compares today's and the current hour's total cost with the daily
        and hourly thresholds. Each exceeded threshold is also sent to the
        registered alert callbacks.

        Args:
            agent_id: Check for specific agent

        Returns:
            List of CostAlert objects for exceeded thresholds
        """
        alerts: List[CostAlert] = []

        # Check daily threshold
        daily_summary = await self.get_daily_summary(agent_id=agent_id)
        daily_alert = self._alert_if_exceeded(
            "daily", "Daily", daily_summary.total_cost_usd, self._daily_threshold
        )
        if daily_alert is not None:
            alerts.append(daily_alert)

        # Check hourly threshold
        hourly_summary = await self.get_hourly_summary(agent_id=agent_id)
        hourly_alert = self._alert_if_exceeded(
            "hourly", "Hourly", hourly_summary.total_cost_usd, self._hourly_threshold
        )
        if hourly_alert is not None:
            alerts.append(hourly_alert)

        return alerts

    def check_request_threshold(self, cost: float) -> Optional[CostAlert]:
        """
        Check if a single request exceeds threshold.

        An exceeded threshold is also sent to the registered alert callbacks.

        Args:
            cost: Cost of the request

        Returns:
            CostAlert if exceeded, None otherwise
        """
        return self._alert_if_exceeded(
            "per_request", "Request", cost, self._request_threshold
        )

    def set_threshold(self, threshold_type: str, value: float) -> None:
        """
        Set a cost threshold.

        Args:
            threshold_type: daily, hourly, or per_request
            value: Threshold value in USD

        Raises:
            ValueError: If ``threshold_type`` is not one of the known types
        """
        attribute_by_type = {
            "daily": "_daily_threshold",
            "hourly": "_hourly_threshold",
            "per_request": "_request_threshold",
        }

        attribute = attribute_by_type.get(threshold_type)
        if attribute is None:
            raise ValueError(f"Unknown threshold type: {threshold_type}")

        setattr(self, attribute, value)

        logger.info(f"Set {threshold_type} threshold to ${value:.2f}")

    def on_alert(self, callback: callable) -> None:
        """
        Register callback for cost alerts.

        Args:
            callback: Function to call with CostAlert
        """
        self._alert_callbacks.append(callback)

    def _fire_alert(self, alert: CostAlert) -> None:
        """Fire alert to registered callbacks.

        A failing callback is logged and does not stop the remaining ones.
        """
        # Iterate over a snapshot so a callback may register further
        # callbacks without disturbing this loop.
        for callback in list(self._alert_callbacks):
            try:
                callback(alert)
            except Exception as e:
                logger.error("Alert callback failed: %s", e)

    def clear_cache(self) -> None:
        """Clear the usage cache."""
        self._usage_cache.clear()
        self._cache_timestamps.clear()

    # ------------------------------------------------------------------
    # Cost breakdowns and reporting
    # ------------------------------------------------------------------

    async def _costs_by(
        self,
        start_date: Optional[date],
        end_date: Optional[date],
        key_of: Any,
    ) -> Dict[str, float]:
        """Sum ``cost_usd`` per ``key_of(entry)`` over an inclusive date range.

        ``start_date`` defaults to today and ``end_date`` to ``start_date``.
        """
        if start_date is None:
            start_date = date.today()
        if end_date is None:
            end_date = start_date

        start_time = datetime.combine(start_date, datetime.min.time())
        end_time = datetime.combine(end_date, datetime.max.time())

        entries = await self._query_window(start_time, end_time)

        costs: Dict[str, float] = {}
        for entry in entries:
            key = key_of(entry)
            costs[key] = costs.get(key, 0.0) + entry.cost_usd

        return costs

    async def get_agent_costs(
        self,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> Dict[str, float]:
        """
        Get total costs per agent for a date range.

        Entries without an agent id are grouped under ``"unknown"``. The
        totals come from one query using the default entry limit.

        Args:
            start_date: Start of range (default: today)
            end_date: End of range (default: same as start_date)

        Returns:
            Dict mapping agent_id to total cost
        """
        return await self._costs_by(
            start_date, end_date, lambda entry: entry.agent_id or "unknown"
        )

    async def get_model_costs(
        self,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> Dict[str, float]:
        """
        Get total costs per model for a date range.

        The totals come from one query using the default entry limit.

        Args:
            start_date: Start of range (default: today)
            end_date: End of range (default: same as start_date)

        Returns:
            Dict mapping model name to total cost
        """
        return await self._costs_by(start_date, end_date, lambda entry: entry.model)

    @staticmethod
    def _report_section(
        title: str,
        buckets: Dict[str, Dict[str, Any]],
        include_requests: bool = False,
    ) -> List[str]:
        """Render one breakdown as report lines, or no lines when it is empty."""
        if not buckets:
            return []

        section = [title]
        for name, data in sorted(buckets.items()):
            section.append(f"  {name}:")
            section.append(f"    Tokens: {data['input_tokens']:,} in / {data['output_tokens']:,} out")
            section.append(f"    Cost:   ${data['cost_usd']:.4f}")
            if include_requests:
                section.append(f"    Reqs:   {data['requests']:,}")
        section.append("")

        return section

    def generate_report(self, summary: UsageSummary) -> str:
        """
        Generate a text report from a usage summary.

        The report lists the period and totals, followed by one section per
        non-empty breakdown (provider, model, agent), each sorted by name.
        Only the provider section includes request counts.

        Args:
            summary: UsageSummary to report

        Returns:
            Formatted text report
        """
        lines = [
            "=" * 60,
            "TOKEN USAGE REPORT",
            "=" * 60,
            f"Period: {summary.period_start.strftime('%Y-%m-%d %H:%M')} - {summary.period_end.strftime('%Y-%m-%d %H:%M')}",
            "",
            "TOTALS:",
            f"  Input tokens:  {summary.total_input_tokens:,}",
            f"  Output tokens: {summary.total_output_tokens:,}",
            f"  Total cost:    ${summary.total_cost_usd:.4f}",
            f"  Requests:      {summary.total_requests:,}",
            "",
        ]

        lines.extend(self._report_section("BY PROVIDER:", summary.by_provider, include_requests=True))
        lines.extend(self._report_section("BY MODEL:", summary.by_model))
        lines.extend(self._report_section("BY AGENT:", summary.by_agent))

        lines.append("=" * 60)

        return "\n".join(lines)


# Module-wide client shared through the two accessors below; stays None
# until set_token_audit_client() is called.
_token_audit_client: Optional[TokenAuditClient] = None


def get_token_audit_client() -> Optional[TokenAuditClient]:
    """Return the shared Token Audit client, or None if none has been set."""
    return _token_audit_client


def set_token_audit_client(client: TokenAuditClient) -> None:
    """Install ``client`` as the shared Token Audit client, replacing any previous one."""
    global _token_audit_client
    _token_audit_client = client
