"""
Graceful Shutdown - Clean shutdown handling for the orchestrator.

Implements:
- Signal handling (SIGINT, SIGTERM)
- In-flight work completion
- Resource cleanup
- State persistence on shutdown
"""

import asyncio
import inspect
import logging
import signal
import sys
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, Dict, List, Any, Callable, Set

# Module-level logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(__name__)


class ShutdownState:
    """String constants naming the phases of a shutdown.

    ``GracefulShutdown.state`` holds one of these values.
    """

    RUNNING: str = "running"              # normal operation
    SHUTTING_DOWN: str = "shutting_down"  # shutdown() has started
    SHUTDOWN: str = "shutdown"            # shutdown() has finished


@dataclass
class ShutdownConfig:
    """Tunable settings for graceful shutdown.

    Attributes:
        grace_period_seconds: Upper bound on how long to wait for in-flight
            work to drain.
        check_interval_seconds: Delay between polls of the in-flight set.
        force_after_grace_period: Whether to force shutdown once the grace
            period has run out.
        persist_state: Whether to persist state on shutdown.
        cleanup_priorities: Default priority per task-name category; a lower
            number runs earlier.
    """

    grace_period_seconds: float = 30.0
    check_interval_seconds: float = 1.0
    force_after_grace_period: bool = True
    persist_state: bool = True
    cleanup_priorities: Dict[str, int] = field(
        default_factory=lambda: dict(
            agents=1,       # stop agents first
            tasks=2,        # then cancel pending tasks
            connections=3,  # then close connections
            state=4,        # then persist state
            logging=5,      # flush logs last
        )
    )


@dataclass
class ShutdownTask:
    """A named cleanup step executed during shutdown.

    Attributes:
        name: Identifier used in logs and as the key in
            ``ShutdownResult.task_results``.
        callback: Callable invoked during shutdown.
        priority: Ordering key; tasks with smaller values run first.
        timeout_seconds: Time limit for awaiting the callback's result.
        required: When True, a failure or timeout of this task marks the
            whole shutdown as unsuccessful (the remaining tasks still run).
    """

    name: str
    callback: Callable
    priority: int = 10
    timeout_seconds: float = 10.0
    required: bool = False


@dataclass
class ShutdownResult:
    """Outcome of a shutdown run.

    Attributes:
        success: False if any required task failed or timed out.
        started_at: When the shutdown began.
        completed_at: When it finished; None until then.
        duration_seconds: ``completed_at - started_at`` in seconds.
        tasks_completed: Tasks that ran without error.
        tasks_failed: Required tasks that failed or timed out.
        tasks_skipped: Non-required tasks that failed or timed out.
        task_results: Per-task status string keyed by task name.
        errors: Human-readable error messages collected along the way.
        was_forced: Whether the shutdown was forced.
    """

    success: bool
    started_at: datetime
    completed_at: Optional[datetime] = None
    duration_seconds: float = 0.0

    tasks_completed: int = 0
    tasks_failed: int = 0
    tasks_skipped: int = 0

    task_results: Dict[str, str] = field(default_factory=dict)
    errors: List[str] = field(default_factory=list)

    was_forced: bool = False


class GracefulShutdown:
    """
    Manages graceful shutdown of the orchestrator.

    Handles:
    - Signal registration (SIGINT, SIGTERM)
    - Running cleanup tasks in priority order
    - Completing in-flight work
    - State persistence

    The state moves ``RUNNING`` -> ``SHUTTING_DOWN`` when :meth:`shutdown`
    starts and ``SHUTTING_DOWN`` -> ``SHUTDOWN`` when it finishes. Calling
    :meth:`shutdown` again after it finished returns the first call's result.
    """

    def __init__(
        self,
        config: Optional[ShutdownConfig] = None,
    ):
        """
        Initialize shutdown manager.

        Args:
            config: Shutdown configuration; defaults to ``ShutdownConfig()``.
        """
        self.config = config or ShutdownConfig()
        self._state = ShutdownState.RUNNING
        self._shutdown_event = asyncio.Event()
        self._shutdown_tasks: List[ShutdownTask] = []
        self._in_flight: Set[str] = set()
        self._shutdown_result: Optional[ShutdownResult] = None

        # Strong references to tasks spawned from signal handlers: the event
        # loop only keeps weak references to tasks, so an unreferenced task
        # may be garbage-collected before it finishes.
        self._signal_tasks: Set[asyncio.Task] = set()

        # Callbacks
        self._pre_shutdown_callbacks: List[Callable] = []
        self._post_shutdown_callbacks: List[Callable] = []

    @property
    def state(self) -> str:
        """Get current shutdown state (one of the ``ShutdownState`` values)."""
        return self._state

    @property
    def is_shutting_down(self) -> bool:
        """True once shutdown has started; stays True after it completes."""
        return self._state != ShutdownState.RUNNING

    def register_signals(self) -> None:
        """
        Register SIGINT/SIGTERM handlers that trigger :meth:`shutdown`.

        Uses ``loop.add_signal_handler`` where available. When the loop does
        not implement it (``NotImplementedError``, e.g. on Windows) it falls
        back to ``signal.signal``, which must be called from the main thread.
        Call this from inside a coroutine so the handlers bind to the loop
        that is actually running.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Not inside a running loop: keep the historical behavior.
            loop = asyncio.get_event_loop()

        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, self._spawn_signal_task, sig)
                logger.debug(f"Registered handler for {sig.name}")
            except NotImplementedError:
                # Windows doesn't support add_signal_handler. A plain signal
                # handler runs outside the loop's control, so hand the work
                # to the loop with call_soon_threadsafe (which also wakes a
                # loop blocked waiting for I/O).
                def _fallback(signum, frame, _loop=loop):
                    if not _loop.is_closed():
                        _loop.call_soon_threadsafe(self._spawn_signal_task, signum)

                signal.signal(sig, _fallback)

    def _spawn_signal_task(self, sig: int) -> None:
        """Start ``_handle_signal`` as a task, keeping a reference until it ends."""
        task = asyncio.create_task(self._handle_signal(sig))
        self._signal_tasks.add(task)
        task.add_done_callback(self._signal_tasks.discard)

    async def _handle_signal(self, sig: int) -> None:
        """
        Handle shutdown signal.

        ``sig`` may be a ``signal.Signals`` member or the bare ``int`` that
        ``signal.signal`` handlers receive; either is logged by name.
        """
        try:
            sig_name = signal.Signals(sig).name
        except ValueError:
            sig_name = str(sig)
        logger.info(f"Received signal {sig_name}, initiating graceful shutdown")
        await self.shutdown()

    def register_task(
        self,
        name: str,
        callback: Callable,
        priority: Optional[int] = None,
        timeout_seconds: float = 10.0,
        required: bool = False,
    ) -> None:
        """
        Register a cleanup task to run during shutdown.

        Args:
            name: Task name for logging
            callback: Callable to run; an ``async def`` function or any
                callable returning an awaitable is awaited.
            priority: Priority (lower = earlier). If omitted, the first
                ``config.cleanup_priorities`` key found as a substring of the
                lower-cased name is used (e.g. "close_connections" ->
                "connections"), else 10.
            timeout_seconds: Max time to await this task's result
            required: If True, failure marks the shutdown as unsuccessful
        """
        if priority is None:
            lowered = name.lower()
            priority = next(
                (
                    prio
                    for category, prio in self.config.cleanup_priorities.items()
                    if category in lowered
                ),
                10,  # Default priority
            )

        task = ShutdownTask(
            name=name,
            callback=callback,
            priority=priority,
            timeout_seconds=timeout_seconds,
            required=required,
        )
        self._shutdown_tasks.append(task)
        logger.debug(f"Registered shutdown task: {name} (priority={priority})")

    def on_pre_shutdown(self, callback: Callable) -> None:
        """Register callback to run before shutdown tasks."""
        self._pre_shutdown_callbacks.append(callback)

    def on_post_shutdown(self, callback: Callable) -> None:
        """Register callback to run after shutdown tasks."""
        self._post_shutdown_callbacks.append(callback)

    def track_work(self, work_id: str) -> None:
        """Track in-flight work item."""
        self._in_flight.add(work_id)

    def complete_work(self, work_id: str) -> None:
        """Mark work item as complete (no-op if it was not tracked)."""
        self._in_flight.discard(work_id)

    def get_in_flight_count(self) -> int:
        """Get number of in-flight work items."""
        return len(self._in_flight)

    async def wait_for_in_flight(self) -> bool:
        """
        Wait for all in-flight work to complete.

        Polls every ``config.check_interval_seconds`` until the in-flight set
        is empty or ``config.grace_period_seconds`` have elapsed. Elapsed
        time is measured with the event loop's monotonic clock, so system
        clock adjustments cannot shorten or stretch the grace period.

        Returns:
            True if all work completed, False if timed out
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.config.grace_period_seconds

        while self._in_flight:
            remaining = deadline - loop.time()
            if remaining <= 0:
                logger.warning(
                    f"Timeout waiting for {len(self._in_flight)} in-flight items: "
                    f"{list(self._in_flight)[:5]}"
                )
                return False

            logger.info(f"Waiting for {len(self._in_flight)} in-flight work items...")
            # Never sleep past the deadline.
            await asyncio.sleep(min(self.config.check_interval_seconds, remaining))

        return True

    @staticmethod
    async def _invoke(callback: Callable, timeout: Optional[float] = None) -> None:
        """
        Call ``callback`` and await its result if the result is awaitable.

        Covers ``async def`` functions as well as plain callables that return
        a coroutine (e.g. ``lambda: stop_agents(agents)``), which would
        otherwise be dropped un-awaited. ``timeout`` bounds only the awaited
        part; synchronous work cannot be interrupted.
        """
        outcome = callback()
        if inspect.isawaitable(outcome):
            if timeout is None:
                await outcome
            else:
                await asyncio.wait_for(outcome, timeout=timeout)

    async def _run_callbacks(
        self,
        callbacks: List[Callable],
        label: str,
        result: ShutdownResult,
    ) -> None:
        """Run pre/post-shutdown callbacks in order, recording failures in ``result``."""
        for callback in callbacks:
            try:
                await self._invoke(callback)
            except Exception as e:
                logger.error(f"{label} callback failed: {e}")
                result.errors.append(f"{label} callback: {str(e)}")

    @staticmethod
    def _record_task_failure(
        result: ShutdownResult,
        task: ShutdownTask,
        status: str,
        detail: str,
    ) -> None:
        """Record a failed/timed-out task; only ``required`` ones fail the shutdown."""
        result.task_results[task.name] = status
        result.errors.append(f"{task.name}: {detail}")
        if task.required:
            result.tasks_failed += 1
            result.success = False
        else:
            # Non-required failures are counted as skipped, not failed.
            result.tasks_skipped += 1

    async def _run_task(self, task: ShutdownTask, result: ShutdownResult) -> None:
        """Run one shutdown task and record its outcome in ``result``."""
        logger.debug(f"Running shutdown task: {task.name}")
        try:
            await self._invoke(task.callback, task.timeout_seconds)
        except asyncio.TimeoutError:
            logger.error(f"Shutdown task timed out: {task.name}")
            self._record_task_failure(result, task, "timeout", "timeout")
        except Exception as e:
            logger.error(f"Shutdown task failed: {task.name}: {e}")
            self._record_task_failure(result, task, f"error: {str(e)}", str(e))
        else:
            result.tasks_completed += 1
            result.task_results[task.name] = "completed"
            logger.debug(f"Shutdown task completed: {task.name}")

    async def shutdown(self, force: bool = False) -> ShutdownResult:
        """
        Initiate graceful shutdown.

        Runs, in order: pre-shutdown callbacks, the wait for in-flight work
        (unless ``force``), registered tasks by ascending priority (ties keep
        registration order), then post-shutdown callbacks. Failures are
        recorded in the result rather than raised.

        Args:
            force: If True, skip waiting for in-flight work

        Returns:
            ShutdownResult with details. If a shutdown already completed, the
            cached result of that run is returned; if one is still in
            progress, a failed result saying so is returned.
        """
        if self._state == ShutdownState.SHUTDOWN and self._shutdown_result is not None:
            # Idempotent: later callers (e.g. a second signal) get the result
            # of the run that actually happened.
            logger.warning("Shutdown already completed")
            return self._shutdown_result

        if self._state != ShutdownState.RUNNING:
            # The first call is still running, so no result exists yet.
            logger.warning("Shutdown already in progress")
            return ShutdownResult(
                success=False,
                started_at=datetime.now(),
                errors=["Shutdown already in progress"],
            )

        self._state = ShutdownState.SHUTTING_DOWN
        self._shutdown_event.set()

        result = ShutdownResult(
            success=True,
            started_at=datetime.now(),
        )

        logger.info("Starting graceful shutdown...")

        await self._run_callbacks(self._pre_shutdown_callbacks, "Pre-shutdown", result)

        # Wait for in-flight work
        if self._in_flight:
            if force:
                # In-flight work is deliberately abandoned; flag it.
                result.was_forced = True
            else:
                logger.info(f"Waiting for {len(self._in_flight)} in-flight work items...")
                completed = await self.wait_for_in_flight()

                if not completed:
                    if self.config.force_after_grace_period:
                        logger.warning("Force completing shutdown after grace period")
                        result.was_forced = True
                    else:
                        # NOTE(review): cleanup still proceeds in this case, as
                        # before; record the unfinished work so it isn't silent.
                        result.errors.append(
                            f"Grace period expired with {len(self._in_flight)} "
                            f"in-flight work items still running"
                        )

        for task in sorted(self._shutdown_tasks, key=lambda t: t.priority):
            await self._run_task(task, result)

        await self._run_callbacks(self._post_shutdown_callbacks, "Post-shutdown", result)

        # Finalize
        self._state = ShutdownState.SHUTDOWN
        result.completed_at = datetime.now()
        result.duration_seconds = (
            result.completed_at - result.started_at
        ).total_seconds()

        self._shutdown_result = result

        logger.info(
            f"Shutdown complete in {result.duration_seconds:.2f}s - "
            f"Tasks: {result.tasks_completed} completed, "
            f"{result.tasks_failed} failed, {result.tasks_skipped} skipped"
        )

        return result

    async def wait_for_shutdown(self) -> None:
        """Wait until shutdown is requested (returns as soon as it starts)."""
        await self._shutdown_event.wait()

    def get_status(self) -> Dict[str, Any]:
        """Get current shutdown status; ``shutdown_result`` is None until completion."""
        return {
            "state": self._state,
            "in_flight_count": len(self._in_flight),
            "registered_tasks": len(self._shutdown_tasks),
            "shutdown_result": {
                "success": self._shutdown_result.success,
                "duration_seconds": self._shutdown_result.duration_seconds,
                "tasks_completed": self._shutdown_result.tasks_completed,
                "errors": self._shutdown_result.errors,
            } if self._shutdown_result else None,
        }


# Convenience functions for common shutdown tasks

async def stop_agents(agents: Dict[str, Any]) -> None:
    """
    Shutdown task: Stop all agents.

    For each agent, calls ``stop()`` if present, otherwise ``cancel()``, and
    awaits the return value only when it is awaitable. This handles both
    ``async def stop()`` agents and objects with synchronous methods such as
    ``asyncio.Task.cancel()`` (which returns a bool). Errors are logged per
    agent, so one failing agent does not prevent the others from stopping.

    Args:
        agents: Mapping of agent id to agent object.
    """
    # Snapshot: stop() may deregister the agent from this very mapping.
    for agent_id, agent in list(agents.items()):
        try:
            if hasattr(agent, 'stop'):
                outcome = agent.stop()
            elif hasattr(agent, 'cancel'):
                outcome = agent.cancel()
            else:
                logger.debug(f"Agent {agent_id} has no stop() or cancel(); skipping")
                continue
            if inspect.isawaitable(outcome):
                await outcome
            logger.debug(f"Stopped agent: {agent_id}")
        except Exception as e:
            logger.error(f"Error stopping agent {agent_id}: {e}")


async def cancel_pending_tasks(task_queue: Any) -> None:
    """
    Shutdown task: Cancel pending tasks.

    Calls ``task_queue.cancel_all()`` if available, otherwise
    ``task_queue.clear()``. The return value is awaited only if it is
    awaitable, so synchronous and asynchronous queue implementations both
    work. Objects with neither method are left untouched; exceptions raised
    by the queue propagate to the caller.
    """
    if hasattr(task_queue, 'cancel_all'):
        outcome = task_queue.cancel_all()
    elif hasattr(task_queue, 'clear'):
        outcome = task_queue.clear()
    else:
        return
    if inspect.isawaitable(outcome):
        await outcome


async def close_connections(connections: List[Any]) -> None:
    """
    Shutdown task: Close connections.

    Calls ``close()`` on every connection that has one and awaits the
    result when it is awaitable. Checking the returned value, rather than
    whether ``close`` is declared ``async def``, also covers wrapped or
    decorated methods that return a coroutine. Errors are logged per
    connection so the remaining connections are still closed.

    Args:
        connections: Connection-like objects; those without ``close`` are skipped.
    """
    for conn in connections:
        try:
            if hasattr(conn, 'close'):
                outcome = conn.close()
                if inspect.isawaitable(outcome):
                    await outcome
        except Exception as e:
            logger.error(f"Error closing connection: {e}")


async def persist_state(db: Any, state: Dict[str, Any]) -> None:
    """
    Shutdown task: Persist state to database.

    Calls ``db.save_shutdown_state(state)`` and awaits the result when it is
    awaitable, so both sync and async implementations work. If ``db`` has no
    such method, a warning is logged instead of claiming success. Exceptions
    from the save propagate to the caller.

    Args:
        db: Object expected to expose ``save_shutdown_state``.
        state: State to hand to ``save_shutdown_state``.
    """
    if not hasattr(db, 'save_shutdown_state'):
        logger.warning(
            "Database has no save_shutdown_state(); shutdown state was not persisted"
        )
        return
    outcome = db.save_shutdown_state(state)
    if inspect.isawaitable(outcome):
        await outcome
    logger.info("State persisted for recovery")


async def flush_logs() -> None:
    """
    Shutdown task: Flush log handlers.

    Flushes every handler attached to the root logger. This is best effort:
    a handler that fails to flush is reported on ``sys.stderr`` (logging the
    failure through a broken handler would be pointless), and the remaining
    handlers are still flushed.
    """
    # Iterate a snapshot so concurrent add/removeHandler cannot disturb the loop.
    for handler in list(logging.root.handlers):
        try:
            handler.flush()
        except Exception as e:
            try:
                print(f"flush_logs: failed to flush {handler!r}: {e}", file=sys.stderr)
            except Exception:
                # stderr itself may already be closed late in shutdown.
                pass


# Singleton instance
_shutdown_manager: Optional[GracefulShutdown] = None


def get_shutdown_manager() -> GracefulShutdown:
    """
    Return the process-wide shutdown manager.

    A manager with the default configuration is created on first use; the
    same instance is returned afterwards until replaced with
    ``set_shutdown_manager``.
    """
    global _shutdown_manager
    if _shutdown_manager is not None:
        return _shutdown_manager
    _shutdown_manager = GracefulShutdown()
    return _shutdown_manager


def set_shutdown_manager(manager: GracefulShutdown) -> None:
    """
    Install ``manager`` as the process-wide shutdown manager.

    Later calls to ``get_shutdown_manager()`` return this instance instead
    of creating a default one.

    Args:
        manager: The instance to use from now on.
    """
    global _shutdown_manager
    _shutdown_manager = manager
