"""
Trace Storage - Persistence layer for traces.

Provides:
- In-memory storage for testing
- SQLite storage for production
- Query capabilities for trace retrieval
"""

import json
import logging
import sqlite3
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional

from .models import Trace, Span, SpanKind, SpanStatus

logger = logging.getLogger(__name__)


class TraceStorage(ABC):
    """Interface that every trace persistence backend implements.

    Concrete backends provide saving, lookup by ID, filtered listing and
    deletion of ``Trace`` objects.
    """

    @abstractmethod
    def save_trace(self, trace: Trace) -> None:
        """Persist ``trace``."""

    @abstractmethod
    def get_trace(self, trace_id: str) -> Optional[Trace]:
        """Return the stored trace with ``trace_id``, or ``None`` if absent."""

    @abstractmethod
    def list_traces(
        self,
        agent_id: Optional[str] = None,
        task_id: Optional[str] = None,
        status: Optional[SpanStatus] = None,
        since: Optional[datetime] = None,
        limit: int = 100,
    ) -> list[Trace]:
        """Return at most ``limit`` traces that match every filter supplied.

        Args:
            agent_id: Only traces for this agent, if given.
            task_id: Only traces for this task, if given.
            status: Only traces with this status, if given.
            since: Only traces whose start time is at or after this, if given.
            limit: Maximum number of traces to return.
        """

    @abstractmethod
    def delete_trace(self, trace_id: str) -> bool:
        """Remove the trace with ``trace_id``; report whether one was removed."""


class InMemoryTraceStorage(TraceStorage):
    """In-memory trace storage for testing.

    Traces live in a dict keyed by ``trace_id``. When the store is full,
    saving a trace with a *new* ID evicts the stored trace with the earliest
    ``start_time``.
    """

    def __init__(self, max_traces: int = 1000):
        """
        Initialize in-memory storage.

        Args:
            max_traces: Maximum traces to store before evicting oldest.
                Must be at least 1.

        Raises:
            ValueError: If ``max_traces`` is less than 1 (such a store could
                never hold the trace being saved).
        """
        if max_traces < 1:
            raise ValueError(f"max_traces must be >= 1, got {max_traces}")
        self._traces: dict[str, Trace] = {}
        self._max_traces = max_traces

    def save_trace(self, trace: Trace) -> None:
        """Save a trace, replacing any stored trace with the same ``trace_id``.

        Eviction only happens when adding a new ID would exceed
        ``max_traces``; overwriting an existing trace never evicts another.
        """
        is_new = trace.trace_id not in self._traces
        if is_new and len(self._traces) >= self._max_traces:
            # Linear scan for the earliest start_time; fine for test-sized stores.
            oldest_id = min(self._traces, key=lambda k: self._traces[k].start_time)
            del self._traces[oldest_id]

        self._traces[trace.trace_id] = trace

    def get_trace(self, trace_id: str) -> Optional[Trace]:
        """Get a trace by ID, or ``None`` if it is not stored."""
        return self._traces.get(trace_id)

    def list_traces(
        self,
        agent_id: Optional[str] = None,
        task_id: Optional[str] = None,
        status: Optional[SpanStatus] = None,
        since: Optional[datetime] = None,
        limit: int = 100,
    ) -> list[Trace]:
        """List traces matching all given filters, newest ``start_time`` first.

        Falsy filter values (``None``, empty string) are treated as "no filter".
        At most ``limit`` traces are returned.
        """
        results = [
            t
            for t in self._traces.values()
            if (not agent_id or t.agent_id == agent_id)
            and (not task_id or t.task_id == task_id)
            and (not status or t.status == status)
            and (not since or t.start_time >= since)
        ]

        # Sort by start time descending
        results.sort(key=lambda t: t.start_time, reverse=True)
        return results[:limit]

    def delete_trace(self, trace_id: str) -> bool:
        """Delete a trace; return ``True`` if one was stored under ``trace_id``."""
        return self._traces.pop(trace_id, None) is not None

    def clear(self) -> None:
        """Clear all traces."""
        self._traces.clear()


class SQLiteTraceStorage(TraceStorage):
    """SQLite-based trace storage for production use.

    Each trace is one row of the ``traces`` table. Scalar fields get their own
    columns; ``tags``, ``metadata`` and the serialized root span are stored as
    JSON text. Timestamps are stored via ``datetime.isoformat()``, so the
    time filters compare those ISO-8601 strings.
    """

    def __init__(self, db_path: str = "traces.db"):
        """
        Initialize SQLite storage and create the schema if needed.

        Args:
            db_path: Path to SQLite database file. ``":memory:"`` gives a
                private in-memory database that lives as long as this object.
        """
        self._db_path = db_path
        # An in-memory database exists only for the lifetime of a single
        # connection. Opening a fresh connection per operation (as done for
        # file databases) would give every call an empty database with no
        # schema, so one connection is kept open for that case instead.
        # NOTE: sqlite3 binds a connection to its creating thread by default.
        self._shared_conn: Optional[sqlite3.Connection] = (
            sqlite3.connect(db_path) if db_path == ":memory:" else None
        )
        self._init_db()

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection wrapped in a transaction.

        The transaction commits if the ``with`` body succeeds and rolls back
        if it raises. ``sqlite3.Connection``'s own context manager does not
        close the connection, so per-call connections are closed here
        explicitly rather than left open until garbage collection. Rows come
        back as ``sqlite3.Row``, which supports both name and index access.
        """
        owned = self._shared_conn is None
        conn = sqlite3.connect(self._db_path) if owned else self._shared_conn
        conn.row_factory = sqlite3.Row
        try:
            with conn:
                yield conn
        finally:
            if owned:
                conn.close()

    def _init_db(self) -> None:
        """Create the ``traces`` table and its indexes if they do not exist."""
        with self._connect() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS traces (
                    trace_id TEXT PRIMARY KEY,
                    task_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    start_time TEXT NOT NULL,
                    end_time TEXT,
                    status TEXT NOT NULL,
                    total_tokens INTEGER DEFAULT 0,
                    total_cost_usd REAL DEFAULT 0.0,
                    total_latency_ms REAL DEFAULT 0.0,
                    span_count INTEGER DEFAULT 0,
                    tags TEXT,
                    metadata TEXT,
                    root_span TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_traces_agent
                ON traces(agent_id)
            """)

            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_traces_task
                ON traces(task_id)
            """)

            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_traces_start_time
                ON traces(start_time)
            """)

    def save_trace(self, trace: Trace) -> None:
        """Save a trace, overwriting any existing row with the same ``trace_id``."""
        with self._connect() as conn:
            # NOTE: INSERT OR REPLACE deletes the existing row before inserting
            # the new one, so ``created_at`` is reset on every re-save.
            conn.execute(
                """
                INSERT OR REPLACE INTO traces (
                    trace_id, task_id, agent_id, start_time, end_time,
                    status, total_tokens, total_cost_usd, total_latency_ms,
                    span_count, tags, metadata, root_span
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    trace.trace_id,
                    trace.task_id,
                    trace.agent_id,
                    trace.start_time.isoformat(),
                    trace.end_time.isoformat() if trace.end_time else None,
                    trace.status.value,
                    trace.total_tokens,
                    trace.total_cost_usd,
                    trace.total_latency_ms,
                    trace.span_count,
                    json.dumps(trace.tags),
                    json.dumps(trace.metadata),
                    json.dumps(trace.root_span.to_dict()) if trace.root_span else None,
                ),
            )

    def get_trace(self, trace_id: str) -> Optional[Trace]:
        """Get a trace by ID, or ``None`` if no row matches."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT * FROM traces WHERE trace_id = ?",
                (trace_id,),
            ).fetchone()
            return self._row_to_trace(row) if row else None

    def list_traces(
        self,
        agent_id: Optional[str] = None,
        task_id: Optional[str] = None,
        status: Optional[SpanStatus] = None,
        since: Optional[datetime] = None,
        limit: int = 100,
    ) -> list[Trace]:
        """List traces matching all given filters, newest ``start_time`` first.

        Falsy filter values are treated as "no filter". All values are bound
        as SQL parameters; only fixed clause text is concatenated.
        """
        query = "SELECT * FROM traces WHERE 1=1"
        params: list[Any] = []

        if agent_id:
            query += " AND agent_id = ?"
            params.append(agent_id)
        if task_id:
            query += " AND task_id = ?"
            params.append(task_id)
        if status:
            query += " AND status = ?"
            params.append(status.value)
        if since:
            query += " AND start_time >= ?"
            params.append(since.isoformat())

        query += " ORDER BY start_time DESC LIMIT ?"
        params.append(limit)

        with self._connect() as conn:
            return [self._row_to_trace(row) for row in conn.execute(query, params)]

    def delete_trace(self, trace_id: str) -> bool:
        """Delete a trace; return ``True`` if a row was removed."""
        with self._connect() as conn:
            cursor = conn.execute(
                "DELETE FROM traces WHERE trace_id = ?",
                (trace_id,),
            )
            return cursor.rowcount > 0

    def _row_to_trace(self, row: sqlite3.Row) -> Trace:
        """Convert a ``traces`` row back into a ``Trace``, decoding JSON columns."""
        trace = Trace(
            trace_id=row["trace_id"],
            task_id=row["task_id"],
            agent_id=row["agent_id"],
            start_time=datetime.fromisoformat(row["start_time"]),
            end_time=datetime.fromisoformat(row["end_time"]) if row["end_time"] else None,
            status=SpanStatus(row["status"]),
            total_tokens=row["total_tokens"],
            total_cost_usd=row["total_cost_usd"],
            total_latency_ms=row["total_latency_ms"],
            span_count=row["span_count"],
            tags=json.loads(row["tags"]) if row["tags"] else {},
            metadata=json.loads(row["metadata"]) if row["metadata"] else {},
        )
        if row["root_span"]:
            trace.root_span = Span.from_dict(json.loads(row["root_span"]))
        return trace

    def get_statistics(
        self,
        since: Optional[datetime] = None,
    ) -> dict[str, Any]:
        """Aggregate trace counts, tokens, cost and latency since a cutoff.

        Args:
            since: Only traces starting at or after this time are counted.
                Defaults to seven days before now (local, naive time).

        Returns:
            Totals plus per-status and per-agent trace counts; NULL aggregates
            (no matching traces) are reported as zero.
        """
        since = since or (datetime.now() - timedelta(days=7))
        cutoff = since.isoformat()

        with self._connect() as conn:
            # Total counts
            row = conn.execute(
                """
                SELECT
                    COUNT(*) as total_traces,
                    SUM(total_tokens) as total_tokens,
                    SUM(total_cost_usd) as total_cost,
                    AVG(total_latency_ms) as avg_latency,
                    SUM(span_count) as total_spans
                FROM traces
                WHERE start_time >= ?
                """,
                (cutoff,),
            ).fetchone()

            # Status breakdown
            status_counts = {
                r[0]: r[1]
                for r in conn.execute(
                    """
                    SELECT status, COUNT(*) as count
                    FROM traces
                    WHERE start_time >= ?
                    GROUP BY status
                    """,
                    (cutoff,),
                )
            }

            # Agent breakdown
            agent_counts = {
                r[0]: r[1]
                for r in conn.execute(
                    """
                    SELECT agent_id, COUNT(*) as count
                    FROM traces
                    WHERE start_time >= ?
                    GROUP BY agent_id
                    """,
                    (cutoff,),
                )
            }

        return {
            "total_traces": row[0] or 0,
            "total_tokens": row[1] or 0,
            "total_cost_usd": row[2] or 0.0,
            "avg_latency_ms": row[3] or 0.0,
            "total_spans": row[4] or 0,
            "by_status": status_counts,
            "by_agent": agent_counts,
            "since": cutoff,
        }

    def cleanup_old_traces(self, older_than_days: int = 30) -> int:
        """
        Delete traces older than specified days.

        The cutoff is computed from local, naive ``datetime.now()``.

        Returns number of deleted traces.
        """
        cutoff = datetime.now() - timedelta(days=older_than_days)

        with self._connect() as conn:
            cursor = conn.execute(
                "DELETE FROM traces WHERE start_time < ?",
                (cutoff.isoformat(),),
            )
            return cursor.rowcount


# Process-wide default storage; populated lazily by get_trace_storage().
_storage: Optional[TraceStorage] = None


def get_trace_storage() -> TraceStorage:
    """Return the global trace storage.

    If no storage is installed, a default ``InMemoryTraceStorage`` is
    created and installed first.
    """
    global _storage
    if _storage is not None:
        return _storage
    _storage = InMemoryTraceStorage()
    return _storage


def set_trace_storage(storage: Optional[TraceStorage]) -> None:
    """Install ``storage`` as the global trace storage.

    Passing ``None`` clears it, so the next ``get_trace_storage()`` call
    creates a fresh in-memory store.
    """
    global _storage
    _storage = storage
