"""
Audit Trail - Comprehensive event logging and audit system.

Implements:
- Structured event logging
- Log rotation
- Audit query API
- Event export functionality
- Compliance-ready audit trail
"""

from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta, date
from enum import Enum
from typing import Optional, Dict, List, Any, Callable, Union
import json
import os
import gzip
import logging
from pathlib import Path

# Module-level logger. AuditTrail uses it to mirror events to the console
# (when enabled) and to report its own I/O, database and callback failures.
logger = logging.getLogger(__name__)


class EventType(Enum):
    """Kinds of events that can be recorded in the audit trail.

    Each member's value is the snake_case string stored in serialized
    events (``AuditEvent.to_dict`` writes ``event_type.value``) and parsed
    back by ``AuditEvent.from_dict`` via ``EventType(value)``. Renaming a
    value therefore makes previously written records fail to parse.
    """

    # Agent events
    AGENT_STARTED = "agent_started"
    AGENT_STOPPED = "agent_stopped"
    AGENT_STUCK = "agent_stuck"
    AGENT_RECOVERED = "agent_recovered"

    # Task events
    TASK_CREATED = "task_created"
    TASK_ASSIGNED = "task_assigned"
    TASK_STARTED = "task_started"
    TASK_COMPLETED = "task_completed"
    TASK_FAILED = "task_failed"
    TASK_CANCELLED = "task_cancelled"

    # Approval events (see AuditTrail.log_approval for the action -> member map)
    APPROVAL_REQUESTED = "approval_requested"
    APPROVAL_GRANTED = "approval_granted"
    APPROVAL_DENIED = "approval_denied"
    APPROVAL_TIMEOUT = "approval_timeout"

    # Action events
    ACTION_EXECUTED = "action_executed"
    ACTION_BLOCKED = "action_blocked"
    ACTION_ESCALATED = "action_escalated"

    # Risk events
    RISK_DETECTED = "risk_detected"
    RISK_MITIGATED = "risk_mitigated"

    # Budget events
    BUDGET_WARNING = "budget_warning"
    BUDGET_EXCEEDED = "budget_exceeded"
    BUDGET_RESET = "budget_reset"

    # System events
    SYSTEM_STARTUP = "system_startup"
    SYSTEM_SHUTDOWN = "system_shutdown"
    SYSTEM_ERROR = "system_error"
    CONFIG_CHANGED = "config_changed"

    # Merge events
    MERGE_REQUESTED = "merge_requested"
    MERGE_COMPLETED = "merge_completed"
    MERGE_BLOCKED = "merge_blocked"

    # Security events
    AUTH_SUCCESS = "auth_success"
    AUTH_FAILURE = "auth_failure"
    ACCESS_DENIED = "access_denied"


class EventSeverity(Enum):
    """Severity level attached to every audit event.

    The value is the lowercase string stored in serialized events.
    ``AuditTrail._log_to_console`` maps each member to the ``logging`` level
    of the same name, and ``AuditTrail.log_event`` drops DEBUG events unless
    ``AuditConfig.log_debug_events`` is enabled.
    """
    DEBUG = "debug"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"


@dataclass
class AuditEvent:
    """A single audit event.

    Instances round-trip through plain dictionaries with ``to_dict`` /
    ``from_dict``; that dictionary form is what gets written to the JSONL
    log files and handed to the database layer.
    """

    # Identity and classification
    event_id: str
    event_type: EventType
    severity: EventSeverity
    timestamp: datetime

    # Context
    message: str
    agent_id: Optional[str] = None
    task_id: Optional[str] = None
    user_id: Optional[str] = None

    # Details: free-form, event-specific payload
    details: Dict[str, Any] = field(default_factory=dict)

    # Source
    source_component: str = "orchestrator"
    source_host: str = ""

    # Correlation
    correlation_id: Optional[str] = None
    parent_event_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for serialization.

        Enums are stored by value and the timestamp as an ISO-8601 string.
        ``details`` is returned by reference, not copied.
        """
        return {
            "event_id": self.event_id,
            "event_type": self.event_type.value,
            "severity": self.severity.value,
            "timestamp": self.timestamp.isoformat(),
            "message": self.message,
            "agent_id": self.agent_id,
            "task_id": self.task_id,
            "user_id": self.user_id,
            "details": self.details,
            "source_component": self.source_component,
            "source_host": self.source_host,
            "correlation_id": self.correlation_id,
            "parent_event_id": self.parent_event_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'AuditEvent':
        """Create an event from the dictionary form produced by ``to_dict``.

        Args:
            data: Mapping with at least ``event_id``, ``event_type``,
                ``severity``, ``timestamp`` and ``message``.

        Returns:
            The reconstructed AuditEvent.

        Raises:
            KeyError: A required key is missing.
            ValueError: ``event_type``/``severity`` is not a known enum value
                or ``timestamp`` is not a valid ISO-8601 string.
        """
        return cls(
            event_id=data["event_id"],
            event_type=EventType(data["event_type"]),
            severity=EventSeverity(data["severity"]),
            timestamp=datetime.fromisoformat(data["timestamp"]),
            message=data["message"],
            agent_id=data.get("agent_id"),
            task_id=data.get("task_id"),
            user_id=data.get("user_id"),
            # A record may carry an explicit ``"details": null``; normalise it
            # to an empty dict so consumers can always treat details as a dict.
            details=data.get("details") or {},
            source_component=data.get("source_component", "orchestrator"),
            source_host=data.get("source_host", ""),
            correlation_id=data.get("correlation_id"),
            parent_event_id=data.get("parent_event_id"),
        )


@dataclass
class AuditConfig:
    """Configuration for audit trail.

    All fields have defaults, so ``AuditConfig()`` is a usable configuration;
    ``AuditTrail`` creates one when none is supplied.
    """

    # Storage
    # Directory holding the audit-<date>.jsonl files; created by AuditTrail.__init__.
    log_directory: str = "logs/audit"
    # Size threshold (in MiB) above which the current file is rotated.
    max_file_size_mb: int = 100
    # NOTE(review): not referenced by the code visible in this file — confirm
    # where (or whether) this cap is enforced.
    max_files: int = 30  # Keep 30 days of logs
    # When True, rotated files are gzip-compressed (".jsonl.gz").
    compress_old_logs: bool = True

    # Retention
    # Age cutoff, in days, used by AuditTrail._cleanup_old_files.
    retention_days: int = 90

    # What to log
    # DEBUG-severity events are skipped by log_event unless this is True.
    log_debug_events: bool = False
    # Mirror each event to the module logger.
    log_to_console: bool = False
    # Persist each event through the db object, when one was supplied.
    log_to_database: bool = True

    # Format
    # When True, records are written with indent=2 instead of one line each.
    pretty_print: bool = False


class AuditTrail:
    """
    Comprehensive audit trail for the orchestrator.

    Records all significant events for compliance,
    debugging, and analysis.
    """

    def __init__(
        self,
        db: Any = None,
        config: Optional[AuditConfig] = None,
    ):
        """
        Set up storage paths and in-memory state for the trail.

        The configured log directory is created (including parents) if it
        does not exist yet.

        Args:
            db: Optional persistence backend, stored as-is on ``self.db``
            config: Settings to use; a default AuditConfig is built when omitted
        """
        import socket

        self.db = db
        self.config = config if config else AuditConfig()

        # Make sure the directory that will hold the log files is there.
        self._log_dir = Path(self.config.log_directory)
        self._log_dir.mkdir(parents=True, exist_ok=True)

        # Active log file and its tracked size; resolved lazily on first write.
        self._current_file: Optional[Path] = None
        self._current_file_size = 0

        # Sequence number used for the next generated event ID.
        self._next_event_id = 1

        # Bounded in-memory history of the latest events.
        self._max_recent = 1000
        self._recent_events: List[AuditEvent] = []

        # Listeners invoked for every logged event.
        self._event_callbacks: List[Callable[[AuditEvent], None]] = []

        # Host name stamped on every event as source_host.
        self._hostname = socket.gethostname()

    def _generate_event_id(self) -> str:
        """Return the next sequential event ID and advance the counter.

        IDs look like ``evt-0000000001`` (zero-padded to ten digits).
        NOTE(review): the counter is initialised to 1 in ``__init__``, so IDs
        are unique only within one AuditTrail instance — confirm that is
        acceptable across restarts.
        """
        sequence = self._next_event_id
        self._next_event_id = sequence + 1
        return f"evt-{sequence:010d}"

    def _get_current_log_file(self) -> Path:
        """Return the path of today's log file, rotating it first if oversized.

        The file is named ``audit-<YYYY-MM-DD>.jsonl``. When the date has
        changed since the last call (or on the first call) the tracked size is
        re-read from disk, or reset to zero if the file does not exist yet.
        """
        todays_file = self._log_dir / f"audit-{date.today().isoformat()}.jsonl"

        if todays_file != self._current_file:
            # First use or date rollover: switch files and resync the size.
            self._current_file = todays_file
            if todays_file.exists():
                self._current_file_size = todays_file.stat().st_size
            else:
                self._current_file_size = 0

        # Rotate once the tracked size exceeds the configured limit.
        size_limit_bytes = self.config.max_file_size_mb * 1024 * 1024
        if self._current_file_size > size_limit_bytes:
            self._rotate_log_file()

        return self._current_file

    def _rotate_log_file(self) -> None:
        """Rotate the current log file into the next free numbered segment.

        The active ``audit-<date>.jsonl`` is moved to
        ``audit-<date>.<n>.jsonl`` (or gzip-compressed into
        ``audit-<date>.<n>.jsonl.gz`` when ``compress_old_logs`` is set),
        where ``n`` is the first unused number in 1..99. A fresh empty active
        file is then created and the tracked size reset. Retention cleanup
        runs afterwards in every case.

        Compression goes through a temporary ``.tmp`` file that is renamed
        into place only once complete, so a failure never leaves a truncated
        archive occupying a segment slot.

        Raises:
            OSError: If compressing, renaming or recreating the file fails.
        """
        current = self._current_file
        if not current or not current.exists():
            return

        import shutil

        compress = self.config.compress_old_logs
        suffix = ".jsonl.gz" if compress else ".jsonl"
        base_name = current.stem

        # Find next available number
        for i in range(1, 100):
            rotated_path = self._log_dir / f"{base_name}.{i}{suffix}"
            if rotated_path.exists():
                continue

            if compress:
                tmp_path = rotated_path.with_name(rotated_path.name + ".tmp")
                try:
                    with open(current, 'rb') as f_in, gzip.open(tmp_path, 'wb') as f_out:
                        shutil.copyfileobj(f_in, f_out)
                    tmp_path.replace(rotated_path)
                except Exception:
                    # Drop the partial archive; the original log is still intact.
                    try:
                        if tmp_path.exists():
                            tmp_path.unlink()
                    except OSError:
                        pass
                    raise
                # Only remove the source once the archive is safely in place.
                current.unlink()
            else:
                current.rename(rotated_path)

            # Create new file
            current.touch()
            self._current_file_size = 0
            break
        else:
            # Every slot is taken: keep appending to the current file, but make
            # the situation visible instead of skipping rotation silently.
            logger.warning(
                f"No free rotation slot for {current.name}; log file not rotated"
            )

        # Cleanup old files
        self._cleanup_old_files()

    def _cleanup_old_files(self) -> None:
        """Remove log files older than the retention period.

        Considers every ``audit-*.jsonl*`` file in the log directory, which
        covers the daily files as well as rotated and compressed segments.
        The file's date is taken from the ``YYYY-MM-DD`` part that directly
        follows the ``audit-`` prefix. Files whose name carries no valid date
        are left alone, and a file that cannot be deleted is reported and
        skipped so the remaining files are still processed.
        """
        cutoff = (datetime.now() - timedelta(days=self.config.retention_days)).date()
        prefix = "audit-"
        date_length = len("YYYY-MM-DD")

        for log_file in self._log_dir.glob("audit-*.jsonl*"):
            # Splitting the name on "-" would isolate only the year (the date
            # itself contains dashes), so slice the full ISO date instead.
            date_part = log_file.name[len(prefix):len(prefix) + date_length]
            try:
                file_date = date.fromisoformat(date_part)
            except ValueError:
                continue

            # Compare whole days: a file is dropped only when its entire day
            # lies before the retention window.
            if file_date >= cutoff:
                continue

            try:
                log_file.unlink()
            except OSError as e:
                logger.warning(f"Could not delete old audit log {log_file.name}: {e}")
                continue

            logger.info(f"Deleted old audit log: {log_file.name}")

    def log_event(
        self,
        event_type: EventType,
        message: str,
        severity: EventSeverity = EventSeverity.INFO,
        agent_id: Optional[str] = None,
        task_id: Optional[str] = None,
        user_id: Optional[str] = None,
        details: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
    ) -> Optional[AuditEvent]:
        """
        Log an audit event.

        The event is written to the log file, then to the database and the
        console when those sinks are enabled, appended to the in-memory
        recent-events buffer, and finally passed to every registered callback.
        A callback that raises is reported and does not stop the others.

        Args:
            event_type: Type of event
            message: Human-readable message
            severity: Event severity
            agent_id: Related agent
            task_id: Related task
            user_id: Related user
            details: Additional details (shallow-copied into the event)
            correlation_id: For correlating related events

        Returns:
            The created AuditEvent, or None when the event is a DEBUG event
            and ``log_debug_events`` is disabled (nothing is recorded then).
        """
        # Skip debug events if not configured
        if severity == EventSeverity.DEBUG and not self.config.log_debug_events:
            return None

        event = AuditEvent(
            event_id=self._generate_event_id(),
            event_type=event_type,
            severity=severity,
            timestamp=datetime.now(),
            message=message,
            agent_id=agent_id,
            task_id=task_id,
            user_id=user_id,
            # Copy so later mutation of the caller's dict cannot alter the
            # recorded event.
            details=dict(details) if details else {},
            source_host=self._hostname,
            correlation_id=correlation_id,
        )

        # Write to file
        self._write_to_file(event)

        # Write to database. Test identity, not truthiness: a backend that
        # defines __len__/__bool__ may be falsy while perfectly usable.
        if self.config.log_to_database and self.db is not None:
            self._write_to_database(event)

        # Log to console
        if self.config.log_to_console:
            self._log_to_console(event)

        # Add to recent events, trimming the oldest entries beyond the cap
        self._recent_events.append(event)
        overflow = len(self._recent_events) - self._max_recent
        if overflow > 0:
            del self._recent_events[:overflow]

        # Fire callbacks. Iterate over a snapshot so a callback may register
        # further callbacks without affecting this dispatch round.
        for callback in tuple(self._event_callbacks):
            try:
                callback(event)
            except Exception as e:
                logger.error(f"Audit callback failed: {e}")

        return event

    def _write_to_file(self, event: AuditEvent) -> None:
        """Append the event to the current log file (best effort).

        Any failure — resolving or rotating the log file, serializing the
        event, or the write itself — is reported through the module logger
        and never raised, so a logging problem cannot break the caller of
        ``log_event``.

        Args:
            event: The event to serialize and append.
        """
        try:
            # Resolved inside the try block: this call may rotate the file,
            # which performs I/O that can fail.
            log_file = self._get_current_log_file()

            line = json.dumps(
                event.to_dict(),
                indent=2 if self.config.pretty_print else None,
                # Deliberate fallback: a detail value JSON cannot encode is
                # stored as its str() rather than losing the whole record.
                default=str,
            ) + "\n"

            with open(log_file, 'a', encoding="utf-8") as f:
                f.write(line)

            # Track the size so rotation does not need a stat() per write.
            self._current_file_size += len(line.encode("utf-8"))

        except Exception as e:
            logger.error(f"Failed to write audit event to file: {e}")

    def _write_to_database(self, event: AuditEvent) -> None:
        """Persist the event through ``self.db.insert_audit_event`` (best effort).

        Does nothing when the database object has no ``insert_audit_event``
        attribute. The event is passed in its dictionary form; any exception
        raised by the call is logged and suppressed.
        """
        if hasattr(self.db, 'insert_audit_event'):
            try:
                self.db.insert_audit_event(event.to_dict())
            except Exception as exc:
                logger.error(f"Failed to write audit event to database: {exc}")

    def _log_to_console(self, event: AuditEvent) -> None:
        """Mirror the event to the module logger.

        The record is emitted at the ``logging`` level matching the event's
        severity, falling back to INFO for an unrecognised severity.
        """
        severity_to_level = {
            EventSeverity.DEBUG: logging.DEBUG,
            EventSeverity.INFO: logging.INFO,
            EventSeverity.WARNING: logging.WARNING,
            EventSeverity.ERROR: logging.ERROR,
            EventSeverity.CRITICAL: logging.CRITICAL,
        }
        level = severity_to_level.get(event.severity, logging.INFO)
        text = f"[AUDIT] {event.event_type.value}: {event.message}"

        logger.log(level, text)

    def on_event(self, callback: Callable[[AuditEvent], None]) -> None:
        """Register a callback to be invoked for every logged event.

        Callbacks are called by ``log_event`` in registration order with the
        newly created AuditEvent; an exception raised by a callback is logged
        there and does not propagate.

        Args:
            callback: Callable taking the AuditEvent as its only argument.
        """
        self._event_callbacks.append(callback)

    # Convenience methods for common events

    def log_agent_started(
        self,
        agent_id: str,
        details: Optional[Dict[str, Any]] = None,
    ) -> AuditEvent:
        """Record an AGENT_STARTED event for the given agent.

        Args:
            agent_id: Agent that started
            details: Additional details

        Returns:
            The created AuditEvent
        """
        message = f"Agent {agent_id} started"
        return self.log_event(
            EventType.AGENT_STARTED,
            message,
            agent_id=agent_id,
            details=details,
        )

    def log_agent_stopped(
        self,
        agent_id: str,
        reason: str = "",
        details: Optional[Dict[str, Any]] = None,
    ) -> AuditEvent:
        """Record an AGENT_STOPPED event for the given agent.

        Args:
            agent_id: Agent that stopped
            reason: Why it stopped; appended to the message
            details: Additional details

        Returns:
            The created AuditEvent
        """
        message = f"Agent {agent_id} stopped: {reason}"
        return self.log_event(
            EventType.AGENT_STOPPED,
            message,
            agent_id=agent_id,
            details=details,
        )

    def log_task_completed(
        self,
        task_id: str,
        agent_id: str,
        success: bool,
        details: Optional[Dict[str, Any]] = None,
    ) -> AuditEvent:
        """Record the outcome of a task.

        A successful task is logged as TASK_COMPLETED at INFO severity, an
        unsuccessful one as TASK_FAILED at WARNING severity.

        Args:
            task_id: Task that finished
            agent_id: Agent that ran it
            success: Whether the task succeeded
            details: Additional details

        Returns:
            The created AuditEvent
        """
        if success:
            event_type = EventType.TASK_COMPLETED
            outcome = 'completed'
            severity = EventSeverity.INFO
        else:
            event_type = EventType.TASK_FAILED
            outcome = 'failed'
            severity = EventSeverity.WARNING

        return self.log_event(
            event_type,
            f"Task {task_id} {outcome} by {agent_id}",
            severity=severity,
            agent_id=agent_id,
            task_id=task_id,
            details=details,
        )

    def log_approval(
        self,
        action: str,  # "requested", "granted", "denied", "timeout"
        agent_id: str,
        target: str,
        user_id: Optional[str] = None,
        details: Optional[Dict[str, Any]] = None,
    ) -> AuditEvent:
        """Log approval event.

        Args:
            action: One of "requested", "granted", "denied", "timeout"
            agent_id: Agent the approval concerns
            target: What the approval is for; also stored in the details
            user_id: Related user
            details: Additional details; merged over the ``target`` entry

        Returns:
            The created AuditEvent. An unrecognised ``action`` is still
            recorded as APPROVAL_REQUESTED, and a warning is emitted so the
            mislabelled record can be noticed.
        """
        event_types = {
            "requested": EventType.APPROVAL_REQUESTED,
            "granted": EventType.APPROVAL_GRANTED,
            "denied": EventType.APPROVAL_DENIED,
            "timeout": EventType.APPROVAL_TIMEOUT,
        }

        event_type = event_types.get(action)
        if event_type is None:
            # Keep the historical fallback, but do not apply it silently: a
            # typo such as "approved" would otherwise hide a granted approval
            # behind a "requested" record.
            logger.warning(
                f"Unknown approval action {action!r}; "
                f"recording as {EventType.APPROVAL_REQUESTED.value}"
            )
            event_type = EventType.APPROVAL_REQUESTED

        return self.log_event(
            event_type,
            f"Approval {action} for {target} by {agent_id}",
            agent_id=agent_id,
            user_id=user_id,
            details={"target": target, **(details or {})},
        )

    def log_action_blocked(
        self,
        agent_id: str,
        action_type: str,
        target: str,
        reason: str,
        details: Optional[Dict[str, Any]] = None,
    ) -> AuditEvent:
        """Record an ACTION_BLOCKED event at WARNING severity.

        The action type, target and reason are stored in the event details;
        entries from ``details`` are merged over them.

        Args:
            agent_id: Agent whose action was blocked
            action_type: Kind of action that was attempted
            target: What the action was aimed at
            reason: Why it was blocked
            details: Additional details

        Returns:
            The created AuditEvent
        """
        merged_details = {
            "action_type": action_type,
            "target": target,
            "reason": reason,
        }
        merged_details.update(details or {})

        message = f"Action blocked: {action_type} on {target} - {reason}"
        return self.log_event(
            EventType.ACTION_BLOCKED,
            message,
            severity=EventSeverity.WARNING,
            agent_id=agent_id,
            details=merged_details,
        )

    def log_system_error(
        self,
        error: str,
        component: str = "orchestrator",
        details: Optional[Dict[str, Any]] = None,
    ) -> AuditEvent:
        """Record a SYSTEM_ERROR event at ERROR severity.

        The component and error text are stored in the event details;
        entries from ``details`` are merged over them.

        Args:
            error: Description of the error
            component: Component the error occurred in
            details: Additional details

        Returns:
            The created AuditEvent
        """
        merged_details = {"component": component, "error": error}
        merged_details.update(details or {})

        message = f"System error in {component}: {error}"
        return self.log_event(
            EventType.SYSTEM_ERROR,
            message,
            severity=EventSeverity.ERROR,
            details=merged_details,
        )

    # Query methods

    def get_recent_events(
        self,
        limit: int = 100,
        event_type: Optional[EventType] = None,
        severity: Optional[EventSeverity] = None,
        agent_id: Optional[str] = None,
    ) -> List[AuditEvent]:
        """Get recent events from memory.

        Only the in-memory buffer is consulted, not the log files.

        Args:
            limit: Maximum number of events to return; zero or a negative
                value yields an empty list
            event_type: Keep only events of this type
            severity: Keep only events of this severity
            agent_id: Keep only events for this agent

        Returns:
            A new list holding the newest matching events, oldest first.
        """
        # Guard first: a slice of [-0:] would return the whole buffer and a
        # negative limit would slice from the wrong end.
        if limit <= 0:
            return []

        matches = [
            event
            for event in self._recent_events
            if (not event_type or event.event_type == event_type)
            and (not severity or event.severity == severity)
            and (not agent_id or event.agent_id == agent_id)
        ]

        return matches[-limit:]

    def query_events(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        event_types: Optional[List[EventType]] = None,
        severities: Optional[List[EventSeverity]] = None,
        agent_id: Optional[str] = None,
        task_id: Optional[str] = None,
        limit: int = 1000,
    ) -> List[AuditEvent]:
        """
        Query events from log files.

        Every file written for a day in the range is searched: size-rotated
        segments (``audit-<date>.<n>.jsonl``, optionally gzip-compressed) in
        rotation order, followed by the day's active ``audit-<date>.jsonl``.

        Args:
            start_time: Start of time range (defaults to seven days ago)
            end_time: End of time range (defaults to now)
            event_types: Filter by event types
            severities: Filter by severities
            agent_id: Filter by agent
            task_id: Filter by task
            limit: Maximum events to return

        Returns:
            List of matching AuditEvents
        """
        events: List[AuditEvent] = []

        # Determine which files to search
        if start_time is None:
            start_time = datetime.now() - timedelta(days=7)
        if end_time is None:
            end_time = datetime.now()

        current_date = start_time.date()
        end_date = end_time.date()

        # Search log files day by day until the range or the limit is exhausted
        while current_date <= end_date and len(events) < limit:
            for log_file in self._log_files_for_date(current_date):
                remaining = limit - len(events)
                if remaining <= 0:
                    break

                events.extend(
                    self._read_events_from_file(
                        log_file,
                        start_time,
                        end_time,
                        event_types,
                        severities,
                        agent_id,
                        task_id,
                        remaining,
                    )
                )

            current_date += timedelta(days=1)

        return events[:limit]

    def _log_files_for_date(self, day: date) -> List[Path]:
        """Return the existing log files for one day, oldest segment first."""
        stem = f"audit-{day.isoformat()}"

        # Rotated segments are named "<stem>.<n>.jsonl" or "<stem>.<n>.jsonl.gz".
        segments = []
        for candidate in self._log_dir.glob(f"{stem}.*"):
            name = candidate.name
            for ending in (".jsonl.gz", ".jsonl"):
                if name.endswith(ending):
                    index = name[len(stem) + 1:len(name) - len(ending)]
                    break
            else:
                continue  # e.g. a leftover temporary file
            if index.isdecimal():
                segments.append((int(index), name, candidate))

        # Lower numbers were rotated earlier and therefore hold older events.
        segments.sort()
        files = [path for _, _, path in segments]

        # The active file holds the newest events of the day.
        active = self._log_dir / f"{stem}.jsonl"
        if active.exists():
            files.append(active)

        return files

    def _read_events_from_file(
        self,
        log_file: Path,
        start_time: datetime,
        end_time: datetime,
        event_types: Optional[List[EventType]],
        severities: Optional[List[EventSeverity]],
        agent_id: Optional[str],
        task_id: Optional[str],
        limit: int,
    ) -> List[AuditEvent]:
        """Read and filter events from a log file.

        Handles plain and gzip-compressed (``.gz``) files, and both record
        layouts the writer can produce: one JSON object per line, or
        pretty-printed objects spanning several lines. Records that cannot be
        parsed are skipped. An error while opening or reading the file is
        logged, and the events collected so far are returned.

        Args:
            log_file: File to read
            start_time: Inclusive lower bound on the event timestamp
            end_time: Inclusive upper bound on the event timestamp
            event_types: Accepted event types, or None/empty for all
            severities: Accepted severities, or None/empty for all
            agent_id: Required agent, or None/empty for any
            task_id: Required task, or None/empty for any
            limit: Maximum number of events to return

        Returns:
            Matching events in file order
        """
        events = []
        opener = gzip.open if log_file.suffix == ".gz" else open

        try:
            with opener(log_file, 'rt', encoding="utf-8") as f:
                # Text of the record currently being assembled. A "{" in the
                # first column always opens a new record: nested lines of a
                # pretty-printed record are indented.
                pending = ""

                for line in f:
                    if len(events) >= limit:
                        break

                    if pending and not line.startswith("{"):
                        pending += line
                    else:
                        pending = line

                    # A record can only be complete on a line ending in "}".
                    if not line.rstrip().endswith("}"):
                        continue

                    try:
                        data = json.loads(pending)
                        event = AuditEvent.from_dict(data)
                    except json.JSONDecodeError:
                        # Incomplete multi-line record (keep collecting) or a
                        # corrupt line (dropped when the next record starts).
                        continue
                    except (KeyError, ValueError, TypeError):
                        # Valid JSON that is not a usable event record.
                        pending = ""
                        continue

                    pending = ""

                    # Apply filters
                    if event.timestamp < start_time or event.timestamp > end_time:
                        continue

                    if event_types and event.event_type not in event_types:
                        continue

                    if severities and event.severity not in severities:
                        continue

                    if agent_id and event.agent_id != agent_id:
                        continue

                    if task_id and event.task_id != task_id:
                        continue

                    events.append(event)

        except Exception as e:
            logger.error(f"Error reading audit log {log_file}: {e}")

        return events

    def export_events(
        self,
        output_file: str,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        format: str = "jsonl",  # jsonl, json, csv
        limit: int = 100000,
    ) -> int:
        """
        Export events to a file.

        Args:
            output_file: Output file path
            start_time: Start of time range
            end_time: End of time range
            format: Output format: "jsonl" (one JSON object per line),
                "json" (a single indented JSON array) or "csv"
            limit: Maximum number of events to export

        Returns:
            Number of events exported

        Raises:
            ValueError: If ``format`` is not one of the supported formats.
                Nothing is queried or written in that case.

        In CSV output the ``details`` column holds the details dictionary
        encoded as JSON. With no events, the CSV file is created empty
        (no header row).
        """
        supported_formats = ("jsonl", "json", "csv")
        if format not in supported_formats:
            # Fail before touching anything rather than reporting a successful
            # export for which no file was written.
            raise ValueError(
                f"Unsupported export format {format!r}; "
                f"expected one of {supported_formats}"
            )

        events = self.query_events(
            start_time=start_time,
            end_time=end_time,
            limit=limit,
        )

        if limit > 0 and len(events) >= limit:
            logger.warning(
                f"Export reached the limit of {limit} events; output may be incomplete"
            )

        rows = [event.to_dict() for event in events]
        output_path = Path(output_file)

        if format == "jsonl":
            with open(output_path, 'w', encoding="utf-8") as f:
                f.writelines(json.dumps(row) + "\n" for row in rows)

        elif format == "json":
            with open(output_path, 'w', encoding="utf-8") as f:
                json.dump(rows, f, indent=2)

        else:  # csv
            import csv
            with open(output_path, 'w', newline='', encoding="utf-8") as f:
                if rows:
                    writer = csv.DictWriter(f, fieldnames=list(rows[0]))
                    writer.writeheader()
                    for row in rows:
                        # Encode the nested dict as JSON so the cell can be
                        # parsed again; the default would be a Python repr.
                        writer.writerow({**row, "details": json.dumps(row["details"])})

        logger.info(f"Exported {len(events)} events to {output_file}")
        return len(events)

    def get_statistics(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 1000,
    ) -> Dict[str, Any]:
        """Get audit statistics for a time period.

        The statistics are computed over at most ``limit`` events returned by
        ``query_events``, so they can be incomplete for busy periods; the
        ``limit_reached`` entry signals when that may be the case.

        Args:
            start_time: Start of time range
            end_time: End of time range
            limit: Maximum number of events to include in the counts

        Returns:
            Dictionary with ``total_events``, per-value counts in ``by_type``,
            ``by_severity`` and ``by_agent`` (events without an agent are not
            counted there), and ``limit_reached``.
        """
        events = self.query_events(
            start_time=start_time,
            end_time=end_time,
            limit=limit,
        )

        by_type: Dict[str, int] = {}
        by_severity: Dict[str, int] = {}
        by_agent: Dict[str, int] = {}

        for event in events:
            type_name = event.event_type.value
            by_type[type_name] = by_type.get(type_name, 0) + 1

            severity_name = event.severity.value
            by_severity[severity_name] = by_severity.get(severity_name, 0) + 1

            if event.agent_id:
                by_agent[event.agent_id] = by_agent.get(event.agent_id, 0) + 1

        return {
            "total_events": len(events),
            "by_type": by_type,
            "by_severity": by_severity,
            "by_agent": by_agent,
            # True when the query stopped at the limit, i.e. more matching
            # events may exist than were counted.
            "limit_reached": len(events) >= limit,
        }


# Process-wide AuditTrail instance. Starts as None and is only ever replaced
# through set_audit_trail(); read it through get_audit_trail().
_audit_trail: Optional[AuditTrail] = None


def get_audit_trail() -> Optional[AuditTrail]:
    """Return the global audit trail instance.

    Returns:
        The instance last passed to ``set_audit_trail``, or None if none
        has been set.
    """
    return _audit_trail


def set_audit_trail(audit: AuditTrail) -> None:
    """Set the global audit trail instance.

    Replaces whatever instance was set before.

    Args:
        audit: Instance that ``get_audit_trail`` should return from now on.
    """
    global _audit_trail
    _audit_trail = audit
