"""
Trace Data Models - Data structures for session tracing.

Provides:
- Span - Single unit of work in a trace
- Trace - Complete execution trace
- SpanKind - Types of spans
- TimelineEvent - Events for replay
"""

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional


class SpanKind(Enum):
    """Type of trace span.

    Each member's value is a lowercase string. The ``to_dict`` methods in
    this module write that string out (``.value``) and ``Span.from_dict``
    looks the member up again from it (``SpanKind(data["kind"])``), so the
    values are part of the serialized format.
    """

    TASK = "task"  # Top-level task execution
    AGENT = "agent"  # Agent operation
    LLM_CALL = "llm_call"  # LLM API call
    TOOL_CALL = "tool_call"  # Tool execution
    APPROVAL = "approval"  # User approval request
    MEMORY_READ = "memory_read"  # Memory read
    MEMORY_WRITE = "memory_write"  # Memory write
    FILE_READ = "file_read"  # File read operation
    FILE_WRITE = "file_write"  # File write operation
    COMMAND = "command"  # Shell command
    WORKFLOW = "workflow"  # Workflow step
    HANDOFF = "handoff"  # Agent handoff
    VOTING = "voting"  # Voting session


class SpanStatus(Enum):
    """Status of a span (also used as the status of a whole trace).

    The string values are what ``to_dict`` emits and what ``from_dict``
    parses back via ``SpanStatus(value)``.
    """

    RUNNING = "running"  # Default for newly created Span and Trace objects
    SUCCESS = "success"  # Default status passed to ``complete()``
    ERROR = "error"  # Forced by ``Span.complete`` when an error message is given
    CANCELLED = "cancelled"
    TIMEOUT = "timeout"


@dataclass
class Span:
    """
    A single span in a trace.

    Represents a unit of work with timing, metrics, and nested children.
    Inspired by OpenTelemetry and LangSmith trace models.

    ``start_time`` may be naive or timezone-aware. Timestamps generated by
    this class ("now" in ``duration`` and ``complete``) use the same tzinfo
    as ``start_time`` so that the two can always be subtracted.
    """

    # Identity
    span_id: str
    trace_id: str
    name: str
    kind: SpanKind
    start_time: datetime

    # Parent relationship (overwritten by the parent's ``add_child``)
    parent_span_id: Optional[str] = None

    # Timing
    end_time: Optional[datetime] = None  # None while the span is still open
    status: SpanStatus = SpanStatus.RUNNING

    # Metrics
    input_tokens: int = 0
    output_tokens: int = 0
    cost_usd: float = 0.0
    latency_ms: float = 0.0  # Filled in by ``complete``

    # Content
    input_data: dict[str, Any] = field(default_factory=dict)
    output_data: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None

    # Nested spans
    children: list["Span"] = field(default_factory=list)

    def _now(self) -> datetime:
        """Return the current time with the same tz-awareness as ``start_time``.

        ``datetime.now(None)`` is a naive local timestamp, so naive start
        times behave exactly as before; aware start times get an aware "now"
        instead of triggering a naive/aware ``TypeError`` on subtraction.
        """
        return datetime.now(self.start_time.tzinfo)

    @property
    def duration(self) -> timedelta:
        """Get span duration.

        Returns ``end_time - start_time`` for a finished span, otherwise the
        time elapsed since ``start_time`` up to now.
        """
        if self.end_time:
            return self.end_time - self.start_time
        return self._now() - self.start_time

    @property
    def duration_ms(self) -> float:
        """Get duration in milliseconds."""
        return self.duration.total_seconds() * 1000

    @property
    def total_tokens(self) -> int:
        """Get total tokens (input + output)."""
        return self.input_tokens + self.output_tokens

    @property
    def is_complete(self) -> bool:
        """Check if span is complete (i.e. ``end_time`` has been set)."""
        return self.end_time is not None

    def add_child(self, child: "Span") -> None:
        """Add a child span and point its ``parent_span_id`` at this span."""
        child.parent_span_id = self.span_id
        self.children.append(child)

    def complete(
        self,
        status: SpanStatus = SpanStatus.SUCCESS,
        output_data: Optional[dict] = None,
        error: Optional[str] = None,
    ) -> None:
        """Mark span as complete.

        Sets ``end_time`` to now, records ``status`` and stores the resulting
        duration in ``latency_ms``.

        Args:
            status: Final status; ignored in favour of ``ERROR`` when
                ``error`` is given.
            output_data: Replaces ``output_data`` when non-empty.
            error: Error message; when non-empty it is stored and the status
                is forced to ``SpanStatus.ERROR``.
        """
        self.end_time = self._now()
        self.status = status
        self.latency_ms = self.duration_ms
        if output_data:
            self.output_data = output_data
        if error:
            self.error = error
            self.status = SpanStatus.ERROR

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary.

        Enums are emitted as their string values, datetimes as ISO-8601
        strings, and children recursively as nested dictionaries.
        """
        return {
            "span_id": self.span_id,
            "trace_id": self.trace_id,
            "parent_span_id": self.parent_span_id,
            "name": self.name,
            "kind": self.kind.value,
            "start_time": self.start_time.isoformat(),
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "status": self.status.value,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost_usd": self.cost_usd,
            "latency_ms": self.latency_ms,
            "input_data": self.input_data,
            "output_data": self.output_data,
            "metadata": self.metadata,
            "error": self.error,
            "children": [c.to_dict() for c in self.children],
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Span":
        """Create from dictionary.

        The input mapping is left unmodified, so the same dictionary can be
        deserialized more than once. Payload dictionaries (``input_data``,
        ``output_data``, ``metadata``) are referenced, not copied.

        Args:
            data: Mapping in the shape produced by ``to_dict``. ``span_id``,
                ``trace_id``, ``name``, ``kind`` and ``start_time`` are
                required; everything else falls back to the field default.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``kind``/``status`` is not a known enum value or a
                timestamp is not a valid ISO-8601 string.
        """
        # Read (do not pop) the children so the caller's dict stays intact;
        # a missing or null entry means "no children".
        children_data = data.get("children") or []
        span = cls(
            span_id=data["span_id"],
            trace_id=data["trace_id"],
            name=data["name"],
            kind=SpanKind(data["kind"]),
            start_time=datetime.fromisoformat(data["start_time"]),
            parent_span_id=data.get("parent_span_id"),
            end_time=datetime.fromisoformat(data["end_time"]) if data.get("end_time") else None,
            status=SpanStatus(data.get("status", "running")),
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost_usd=data.get("cost_usd", 0.0),
            latency_ms=data.get("latency_ms", 0.0),
            input_data=data.get("input_data", {}),
            output_data=data.get("output_data", {}),
            metadata=data.get("metadata", {}),
            error=data.get("error"),
        )
        span.children = [cls.from_dict(c) for c in children_data]
        return span


@dataclass
class Trace:
    """
    Complete trace of a task execution.

    Contains the root span and aggregated metrics for a full task execution.

    ``start_time`` may be naive or timezone-aware. Timestamps generated by
    this class ("now" in ``duration`` and ``complete``) use the same tzinfo
    as ``start_time`` so that the two can always be subtracted.
    """

    # Identity
    trace_id: str
    task_id: str
    agent_id: str
    start_time: datetime

    # Status
    end_time: Optional[datetime] = None  # None while the trace is still open
    status: SpanStatus = SpanStatus.RUNNING

    # Root span
    root_span: Optional[Span] = None

    # Aggregated metrics (recomputed by ``aggregate_metrics``)
    total_tokens: int = 0
    total_cost_usd: float = 0.0
    total_latency_ms: float = 0.0
    span_count: int = 0

    # Categorization
    tags: dict[str, str] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    def _now(self) -> datetime:
        """Return the current time with the same tz-awareness as ``start_time``.

        ``datetime.now(None)`` is a naive local timestamp, so naive start
        times behave exactly as before; aware start times get an aware "now"
        instead of triggering a naive/aware ``TypeError`` on subtraction.
        """
        return datetime.now(self.start_time.tzinfo)

    @property
    def duration(self) -> timedelta:
        """Get trace duration.

        Returns ``end_time - start_time`` for a finished trace, otherwise the
        time elapsed since ``start_time`` up to now.
        """
        if self.end_time:
            return self.end_time - self.start_time
        return self._now() - self.start_time

    @property
    def duration_ms(self) -> float:
        """Get duration in milliseconds."""
        return self.duration.total_seconds() * 1000

    @property
    def is_complete(self) -> bool:
        """Check if trace is complete (i.e. ``end_time`` has been set)."""
        return self.end_time is not None

    def aggregate_metrics(self) -> None:
        """Aggregate metrics from all spans.

        When a root span is present, ``total_tokens``, ``total_cost_usd`` and
        ``span_count`` are reset and recomputed over the whole span tree.
        Without a root span those three totals are left untouched.
        ``total_latency_ms`` is always refreshed from the trace's own
        duration, since it does not depend on any span.
        """
        if self.root_span is not None:
            self.total_tokens = 0
            self.total_cost_usd = 0.0
            self.span_count = 0

            def aggregate_span(span: Span) -> None:
                # Pre-order walk: count the span itself, then its subtree.
                self.total_tokens += span.total_tokens
                self.total_cost_usd += span.cost_usd
                self.span_count += 1
                for child in span.children:
                    aggregate_span(child)

            aggregate_span(self.root_span)

        # Previously skipped when there was no root span, which left a
        # completed, span-less trace reporting zero latency.
        self.total_latency_ms = self.duration_ms

    def complete(self, status: SpanStatus = SpanStatus.SUCCESS) -> None:
        """Mark trace as complete.

        Sets ``end_time`` to now, records ``status`` and refreshes the
        aggregated metrics.
        """
        self.end_time = self._now()
        self.status = status
        self.aggregate_metrics()

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary.

        Enums are emitted as their string values, datetimes as ISO-8601
        strings, and the root span (if any) as a nested dictionary.
        """
        return {
            "trace_id": self.trace_id,
            "task_id": self.task_id,
            "agent_id": self.agent_id,
            "start_time": self.start_time.isoformat(),
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "status": self.status.value,
            "root_span": self.root_span.to_dict() if self.root_span else None,
            "total_tokens": self.total_tokens,
            "total_cost_usd": self.total_cost_usd,
            "total_latency_ms": self.total_latency_ms,
            "span_count": self.span_count,
            "tags": self.tags,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Trace":
        """Create from dictionary.

        Args:
            data: Mapping in the shape produced by ``to_dict``. ``trace_id``,
                ``task_id``, ``agent_id`` and ``start_time`` are required;
                everything else falls back to the field default. Stored
                totals are taken as-is and not recomputed.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``status`` is not a known enum value or a
                timestamp is not a valid ISO-8601 string.
        """
        trace = cls(
            trace_id=data["trace_id"],
            task_id=data["task_id"],
            agent_id=data["agent_id"],
            start_time=datetime.fromisoformat(data["start_time"]),
            end_time=datetime.fromisoformat(data["end_time"]) if data.get("end_time") else None,
            status=SpanStatus(data.get("status", "running")),
            total_tokens=data.get("total_tokens", 0),
            total_cost_usd=data.get("total_cost_usd", 0.0),
            total_latency_ms=data.get("total_latency_ms", 0.0),
            span_count=data.get("span_count", 0),
            tags=data.get("tags", {}),
            metadata=data.get("metadata", {}),
        )
        # A missing or null root span simply leaves ``root_span`` as None.
        root_data = data.get("root_span")
        if root_data:
            trace.root_span = Span.from_dict(root_data)
        return trace


@dataclass
class TimelineEvent:
    """One timestamped entry in an execution timeline, used for replay."""

    timestamp: datetime
    event_type: str  # e.g. "span_start", "span_end", "llm_response"
    span_id: str
    span_name: str
    span_kind: SpanKind
    depth: int = 0  # presumably the span's nesting level -- confirm with producer
    data: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the event to a plain dictionary.

        The timestamp is rendered as an ISO-8601 string and the span kind as
        its enum value; all other fields are passed through unchanged.
        """
        serialized: dict[str, Any] = {}
        serialized["timestamp"] = self.timestamp.isoformat()
        serialized["event_type"] = self.event_type
        serialized["span_id"] = self.span_id
        serialized["span_name"] = self.span_name
        serialized["span_kind"] = self.span_kind.value
        serialized["depth"] = self.depth
        serialized["data"] = self.data
        return serialized


@dataclass
class DecisionPoint:
    """Record of a point in a trace where an agent chose between options."""

    span_id: str
    span_kind: SpanKind
    timestamp: datetime
    input_context: dict[str, Any]
    decision: str
    alternatives: list[str] = field(default_factory=list)
    rationale: Optional[str] = None
    confidence: float = 1.0

    def to_dict(self) -> dict[str, Any]:
        """Serialize the decision point to a plain dictionary.

        The span kind is rendered as its enum value and the timestamp as an
        ISO-8601 string; all other fields are passed through unchanged.
        """
        return dict(
            span_id=self.span_id,
            span_kind=self.span_kind.value,
            timestamp=self.timestamp.isoformat(),
            input_context=self.input_context,
            decision=self.decision,
            alternatives=self.alternatives,
            rationale=self.rationale,
            confidence=self.confidence,
        )
