"""
Agent Health Check - Stuck detection and health monitoring.

This module provides heuristics to detect when an agent is stuck:
1. Idle but burning tokens (no progress, high cost)
2. Repeated error loops (same error 3+ times)
3. Edit oscillation (undo/redo pattern)
4. Awaiting human decision (blocked on approval)
5. Timeout exceeded (task running too long)
6. No progress (no output for an extended period)

Usage:
    from agent_orchestrator.control.health import AgentHealthCheck

    health_check = AgentHealthCheck(db)
    sample = health_check.sample_agent("claude-code")

    if sample.is_stuck:
        print(f"Agent stuck: {sample.stuck_reason}")
"""

import logging
import re
from collections import Counter
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional

from ..persistence.database import OrchestratorDB
from ..persistence.models import HealthSample


logger = logging.getLogger(__name__)


class StuckReason(Enum):
    """
    Reasons an agent might be stuck.

    The ``.value`` string of a member is what
    ``HealthCheckResult.to_health_sample`` stores in
    ``HealthSample.stuck_reason`` when the agent is stuck.
    """

    # No heuristic fired; the agent is not considered stuck.
    NOT_STUCK = "not_stuck"
    # Token usage is at or above the threshold but output has gone quiet.
    IDLE_BURNING_TOKENS = "idle_burning_tokens"
    # Recent error outputs look like the same failure recurring.
    REPEATED_ERROR_LOOP = "repeated_error_loop"
    # The same file keeps being edited (possible undo/redo cycle).
    EDIT_OSCILLATION = "edit_oscillation"
    # A pending approval for this agent has waited too long.
    AWAITING_APPROVAL = "awaiting_approval"
    # The current task has run longer than the configured timeout.
    TIMEOUT_EXCEEDED = "timeout_exceeded"
    # No output has been observed for longer than the configured threshold.
    NO_PROGRESS = "no_progress"
    # The two members below are not assigned by AgentHealthCheck.sample_agent.
    RESOURCE_EXHAUSTED = "resource_exhausted"
    UNKNOWN_STATE = "unknown_state"


class AgentState(Enum):
    """
    Observed agent states.

    ``AgentHealthCheck._determine_state`` only ever returns
    WAITING_APPROVAL, ERROR, RUNNING, IDLE or UNKNOWN; the remaining
    members are not produced by that method.
    """

    # Has cached output, last seen between 60 and 300 seconds ago.
    IDLE = "idle"
    # Output seen within the last 60 seconds, or the task timed out
    # (still running, just for too long).
    RUNNING = "running"
    WAITING_INPUT = "waiting_input"
    # Stuck on a pending approval.
    WAITING_APPROVAL = "waiting_approval"
    # Stuck in a repeated-error loop or an edit oscillation.
    ERROR = "error"
    COMPLETED = "completed"
    TERMINATED = "terminated"
    # No cached output, or the last output is 300 seconds old or more.
    UNKNOWN = "unknown"


@dataclass
class HealthCheckResult:
    """Outcome of one health sample taken for a single agent."""

    # Identity and time of the sample.
    agent_id: str
    timestamp: datetime
    # Derived state and verdicts.
    state: AgentState
    is_healthy: bool
    is_stuck: bool
    stuck_reason: StuckReason
    stuck_details: str = ""
    # Supporting metrics.
    tokens_used_recent: int = 0
    errors_recent: int = 0
    last_output_age_seconds: float = 0
    # Free-form advice and raw counters for callers that want more detail.
    recommendations: list[str] = field(default_factory=list)
    raw_metrics: dict[str, Any] = field(default_factory=dict)

    def to_health_sample(self) -> HealthSample:
        """
        Build the ``HealthSample`` row that represents this result.

        Returns:
            A ``HealthSample`` whose ``stuck_reason`` is the enum value
            when the agent is stuck (otherwise ``None``) and whose
            ``last_stdout_at`` is reconstructed from the output age, or
            ``None`` when no output was ever seen or the age cannot be
            turned into a timestamp.
        """
        seconds_since_output = self.last_output_age_seconds

        last_stdout_at = None
        # An infinite age means "never saw output"; leave the timestamp unset.
        if seconds_since_output != float('inf'):
            try:
                last_stdout_at = datetime.now() - timedelta(seconds=seconds_since_output)
            except (OverflowError, ValueError):
                # Age too large (or NaN) to express as a datetime.
                last_stdout_at = None

        if self.is_stuck:
            reason_value = self.stuck_reason.value
        else:
            reason_value = None

        return HealthSample(
            agent_id=self.agent_id,
            sampled_at=self.timestamp,
            is_stuck=self.is_stuck,
            stuck_reason=reason_value,
            token_burn_rate=float(self.tokens_used_recent),
            error_count=self.errors_recent,
            last_stdout_at=last_stdout_at,
        )


@dataclass
class StuckDetectionConfig:
    """
    Configuration for stuck detection heuristics.

    All time values are in seconds. The thresholds are read by the
    ``_check_*`` methods of ``AgentHealthCheck``.
    """

    # Idle but burning tokens: flagged when token usage is at least
    # idle_token_threshold AND the last output is older than
    # idle_time_threshold_seconds.
    idle_token_threshold: int = 1000  # Tokens with no output
    idle_time_threshold_seconds: int = 300  # 5 minutes

    # Repeated errors: minimum number of cached errors before the
    # similarity check runs. error_window_seconds is also the retention
    # window applied to cached outputs in AgentHealthCheck._record_output.
    repeated_error_threshold: int = 3  # Same error N times
    error_window_seconds: int = 600  # Within 10 minutes

    # Edit oscillation: a file edited at least oscillation_threshold times
    # is reported, once at least 2 * oscillation_threshold edits are cached.
    oscillation_threshold: int = 3  # N undo/redo cycles
    oscillation_window_seconds: int = 300  # Within 5 minutes

    # Timeout: maximum task runtime before TIMEOUT_EXCEEDED is reported.
    task_timeout_seconds: int = 3600  # 1 hour default

    # No progress: maximum silence (or time since task start when no
    # output has been seen yet) before NO_PROGRESS is reported.
    no_progress_threshold_seconds: int = 600  # 10 minutes no output

    # Approval timeout: maximum age of the oldest pending approval.
    approval_timeout_seconds: int = 1800  # 30 minutes waiting


class AgentHealthCheck:
    """
    Performs health checks on agents to detect stuck states.

    The health checker samples agent state periodically and
    applies heuristics to detect problematic patterns.
    """

    def __init__(
        self,
        db: OrchestratorDB,
        config: Optional[StuckDetectionConfig] = None,
    ):
        """
        Initialize the health checker.

        Args:
            db: Database for historical data
            config: Detection thresholds
        """
        self.db = db
        self.config = config or StuckDetectionConfig()

        # Recent observations cache
        self._recent_outputs: dict[str, list[tuple[datetime, str]]] = {}
        self._recent_errors: dict[str, list[tuple[datetime, str]]] = {}
        self._recent_edits: dict[str, list[tuple[datetime, str, str]]] = {}

    def sample_agent(
        self,
        agent_id: str,
        current_output: Optional[str] = None,
        tokens_used: int = 0,
        task_start_time: Optional[datetime] = None,
    ) -> HealthCheckResult:
        """
        Sample an agent's health.

        Args:
            agent_id: The agent to check
            current_output: Recent output from the agent
            tokens_used: Tokens used in current session
            task_start_time: When the current task started

        Returns:
            HealthCheckResult with stuck detection
        """
        now = datetime.now()

        # Record output if provided
        if current_output:
            self._record_output(agent_id, current_output)

        # Run all stuck detection heuristics
        stuck_reason = StuckReason.NOT_STUCK
        stuck_details = ""
        recommendations = []

        # Check idle but burning tokens
        idle_result = self._check_idle_burning_tokens(agent_id, tokens_used)
        if idle_result[0]:
            stuck_reason = StuckReason.IDLE_BURNING_TOKENS
            stuck_details = idle_result[1]
            recommendations.append("Consider sending an unstick prompt")

        # Check repeated error loop
        if stuck_reason == StuckReason.NOT_STUCK:
            error_result = self._check_repeated_errors(agent_id)
            if error_result[0]:
                stuck_reason = StuckReason.REPEATED_ERROR_LOOP
                stuck_details = error_result[1]
                recommendations.append("Agent may need different approach")
                recommendations.append("Consider escalating to human")

        # Check edit oscillation
        if stuck_reason == StuckReason.NOT_STUCK:
            oscillation_result = self._check_edit_oscillation(agent_id)
            if oscillation_result[0]:
                stuck_reason = StuckReason.EDIT_OSCILLATION
                stuck_details = oscillation_result[1]
                recommendations.append("Agent is undoing/redoing changes")

        # Check timeout
        if stuck_reason == StuckReason.NOT_STUCK and task_start_time:
            timeout_result = self._check_timeout(task_start_time)
            if timeout_result[0]:
                stuck_reason = StuckReason.TIMEOUT_EXCEEDED
                stuck_details = timeout_result[1]
                recommendations.append("Consider cancelling and reassigning")

        # Check awaiting approval
        if stuck_reason == StuckReason.NOT_STUCK:
            approval_result = self._check_awaiting_approval(agent_id)
            if approval_result[0]:
                stuck_reason = StuckReason.AWAITING_APPROVAL
                stuck_details = approval_result[1]
                recommendations.append("Pending approval needs attention")

        # Check no progress
        if stuck_reason == StuckReason.NOT_STUCK:
            progress_result = self._check_no_progress(agent_id, task_start_time)
            if progress_result[0]:
                stuck_reason = StuckReason.NO_PROGRESS
                stuck_details = progress_result[1]
                recommendations.append("No output for extended period")

        # Determine overall state
        state = self._determine_state(agent_id, stuck_reason)
        is_stuck = stuck_reason != StuckReason.NOT_STUCK
        is_healthy = not is_stuck and state not in {AgentState.ERROR, AgentState.UNKNOWN}

        # Calculate metrics
        last_output_age = self._get_last_output_age(agent_id)
        recent_errors = len(self._recent_errors.get(agent_id, []))

        result = HealthCheckResult(
            agent_id=agent_id,
            timestamp=now,
            state=state,
            is_healthy=is_healthy,
            is_stuck=is_stuck,
            stuck_reason=stuck_reason,
            stuck_details=stuck_details,
            tokens_used_recent=tokens_used,
            errors_recent=recent_errors,
            last_output_age_seconds=last_output_age,
            recommendations=recommendations,
            raw_metrics={
                "tokens_used": tokens_used,
                "recent_outputs": len(self._recent_outputs.get(agent_id, [])),
                "recent_errors": recent_errors,
            },
        )

        # Store in database
        sample = result.to_health_sample()
        self.db.record_health_sample(sample)

        return result

    def _record_output(self, agent_id: str, output: str) -> None:
        """Record an output observation."""
        now = datetime.now()

        if agent_id not in self._recent_outputs:
            self._recent_outputs[agent_id] = []

        self._recent_outputs[agent_id].append((now, output))

        # Keep only recent observations
        cutoff = now - timedelta(seconds=self.config.error_window_seconds)
        self._recent_outputs[agent_id] = [
            (t, o) for t, o in self._recent_outputs[agent_id] if t > cutoff
        ]

        # Check for errors in output
        if self._is_error_output(output):
            if agent_id not in self._recent_errors:
                self._recent_errors[agent_id] = []
            self._recent_errors[agent_id].append((now, output))

            # Clean old errors
            self._recent_errors[agent_id] = [
                (t, e) for t, e in self._recent_errors[agent_id] if t > cutoff
            ]

        # Check for edit patterns
        edit_match = self._extract_edit_pattern(output)
        if edit_match:
            if agent_id not in self._recent_edits:
                self._recent_edits[agent_id] = []
            self._recent_edits[agent_id].append((now, edit_match[0], edit_match[1]))

    def _is_error_output(self, output: str) -> bool:
        """Check if output contains error indicators."""
        error_patterns = [
            r"error:",
            r"Error:",
            r"ERROR",
            r"exception:",
            r"Exception:",
            r"FAILED",
            r"failed:",
            r"Traceback",
            r"SyntaxError",
            r"TypeError",
            r"ValueError",
            r"KeyError",
            r"AttributeError",
            r"ModuleNotFoundError",
        ]

        for pattern in error_patterns:
            if re.search(pattern, output):
                return True
        return False

    def _extract_edit_pattern(self, output: str) -> Optional[tuple[str, str]]:
        """Extract file edit patterns from output."""
        # Look for edit/write file patterns
        edit_patterns = [
            r"(?:Edit|Write|Modified)\s+['\"]?([^'\"]+\.[a-z]+)['\"]?",
            r"wrote\s+to\s+['\"]?([^'\"]+\.[a-z]+)['\"]?",
        ]

        for pattern in edit_patterns:
            match = re.search(pattern, output, re.IGNORECASE)
            if match:
                return (match.group(1), "edit")
        return None

    def _check_idle_burning_tokens(
        self,
        agent_id: str,
        tokens_used: int,
    ) -> tuple[bool, str]:
        """Check for idle but burning tokens pattern."""
        if tokens_used < self.config.idle_token_threshold:
            return False, ""

        last_output_age = self._get_last_output_age(agent_id)

        if last_output_age > self.config.idle_time_threshold_seconds:
            return True, f"Used {tokens_used} tokens but no output for {last_output_age:.0f}s"

        return False, ""

    def _check_repeated_errors(self, agent_id: str) -> tuple[bool, str]:
        """Check for repeated error loop pattern."""
        recent_errors = self._recent_errors.get(agent_id, [])

        if len(recent_errors) < self.config.repeated_error_threshold:
            return False, ""

        # Check if errors are similar
        error_texts = [e[1][:100] for e in recent_errors[-5:]]

        # Simple similarity check - same prefix
        if len(set(error_texts)) <= 2:  # At most 2 unique error types
            return True, f"Same error repeated {len(recent_errors)} times"

        return False, ""

    def _check_edit_oscillation(self, agent_id: str) -> tuple[bool, str]:
        """Check for edit oscillation (undo/redo) pattern."""
        recent_edits = self._recent_edits.get(agent_id, [])

        if len(recent_edits) < self.config.oscillation_threshold * 2:
            return False, ""

        # Check if same file edited multiple times
        files_edited = [e[1] for e in recent_edits]
        file_counts = {}
        for f in files_edited:
            file_counts[f] = file_counts.get(f, 0) + 1

        # If any file edited more than threshold times
        for file, count in file_counts.items():
            if count >= self.config.oscillation_threshold:
                return True, f"File '{file}' edited {count} times (possible oscillation)"

        return False, ""

    def _check_timeout(self, task_start_time: datetime) -> tuple[bool, str]:
        """Check if task has exceeded timeout."""
        elapsed = (datetime.now() - task_start_time).total_seconds()

        if elapsed > self.config.task_timeout_seconds:
            return True, f"Task running for {elapsed/60:.1f} minutes (timeout: {self.config.task_timeout_seconds/60:.0f}m)"

        return False, ""

    def _check_awaiting_approval(self, agent_id: str) -> tuple[bool, str]:
        """Check if agent is waiting on approval for too long."""
        pending = self.db.get_pending_approvals()
        agent_pending = [a for a in pending if a.agent_id == agent_id]

        if not agent_pending:
            return False, ""

        oldest = min(a.requested_at for a in agent_pending)
        wait_time = (datetime.now() - oldest).total_seconds()

        if wait_time > self.config.approval_timeout_seconds:
            return True, f"Waiting {wait_time/60:.1f} minutes for approval"

        return False, ""

    def _check_no_progress(
        self,
        agent_id: str,
        task_start_time: Optional[datetime] = None,
    ) -> tuple[bool, str]:
        """Check for no progress (no output) pattern."""
        last_output_age = self._get_last_output_age(agent_id)

        # If we have a task start time, treat it as the "last output" time
        # if we haven't seen any output yet. This gives a grace period.
        if task_start_time and last_output_age == float('inf'):
            elapsed = (datetime.now() - task_start_time).total_seconds()
            if elapsed < self.config.no_progress_threshold_seconds:
                return False, ""
            return True, f"No output for {elapsed/60:.1f} minutes (since start)"

        if last_output_age > self.config.no_progress_threshold_seconds:
            return True, f"No output for {last_output_age/60:.1f} minutes"

        return False, ""

    def _get_last_output_age(self, agent_id: str) -> float:
        """Get seconds since last output."""
        outputs = self._recent_outputs.get(agent_id, [])

        if not outputs:
            return float('inf')

        last_time = max(t for t, _ in outputs)
        return (datetime.now() - last_time).total_seconds()

    def _determine_state(self, agent_id: str, stuck_reason: StuckReason) -> AgentState:
        """Determine the agent's current state."""
        if stuck_reason == StuckReason.AWAITING_APPROVAL:
            return AgentState.WAITING_APPROVAL

        if stuck_reason in {
            StuckReason.REPEATED_ERROR_LOOP,
            StuckReason.EDIT_OSCILLATION,
        }:
            return AgentState.ERROR

        if stuck_reason == StuckReason.TIMEOUT_EXCEEDED:
            return AgentState.RUNNING  # Still running, just too long

        # Check recent activity
        outputs = self._recent_outputs.get(agent_id, [])
        if outputs:
            last_age = self._get_last_output_age(agent_id)
            if last_age < 60:  # Active in last minute
                return AgentState.RUNNING
            elif last_age < 300:  # Active in last 5 minutes
                return AgentState.IDLE

        return AgentState.UNKNOWN

    def is_stuck(self, agent_id: str, **kwargs) -> bool:
        """Quick check if an agent is stuck."""
        result = self.sample_agent(agent_id, **kwargs)
        return result.is_stuck

    def get_recent_samples(self, agent_id: str, limit: int = 10) -> list[HealthSample]:
        """Get recent health samples for an agent."""
        return self.db.get_health_samples(agent_id, limit=limit)

    def clear_observations(self, agent_id: str) -> None:
        """Clear cached observations for an agent (e.g., after task completion)."""
        self._recent_outputs.pop(agent_id, None)
        self._recent_errors.pop(agent_id, None)
        self._recent_edits.pop(agent_id, None)


def generate_unstick_prompt(stuck_reason: StuckReason, details: str) -> str:
    """
    Build the prompt to send to an agent that appears stuck.

    Args:
        stuck_reason: Why the agent is stuck
        details: Additional context; only included in the prompt for
            repeated error loops

    Returns:
        A reason-specific prompt, or a generic status-update request for
        reasons without a dedicated prompt.
    """
    if stuck_reason == StuckReason.IDLE_BURNING_TOKENS:
        return (
            "You seem to be processing without producing output. "
            "Please provide a status update on what you're working on, "
            "or try a different approach if you're blocked."
        )

    if stuck_reason == StuckReason.REPEATED_ERROR_LOOP:
        return (
            f"You've encountered the same error multiple times: {details}\n"
            "Please try a different approach. Consider:\n"
            "1. Reading error messages carefully\n"
            "2. Checking documentation or examples\n"
            "3. Simplifying the problem\n"
            "4. Asking for help if truly stuck"
        )

    if stuck_reason == StuckReason.EDIT_OSCILLATION:
        return (
            "You appear to be making and undoing the same changes repeatedly. "
            "Please step back and reconsider your approach. "
            "What is the core issue you're trying to solve?"
        )

    if stuck_reason == StuckReason.NO_PROGRESS:
        return (
            "No activity detected for a while. Please provide a status update:\n"
            "- What have you accomplished?\n"
            "- What are you currently working on?\n"
            "- Are you blocked on anything?"
        )

    if stuck_reason == StuckReason.TIMEOUT_EXCEEDED:
        return (
            "This task has been running longer than expected. "
            "Please summarize your progress and indicate if you need more time "
            "or if the task should be broken into smaller pieces."
        )

    # Every other reason gets the generic request.
    return "Please provide a status update on your current progress."
