"""
Tests for User Interaction Detection and Auto-Response handling.
"""

import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch, AsyncMock

from agent_orchestrator.interaction import (
    InteractionType,
    RiskLevel,
    DetectedInteraction,
    InteractionDetector,
    get_interaction_detector,
    set_interaction_detector,
    ResponseAction,
    EscalationLevel,
    ResponseDecision,
    ResponsePolicy,
    AutoResponseHandler,
    InteractionRouter,
    get_auto_handler,
    set_auto_handler,
    get_interaction_router,
    set_interaction_router,
)
from agent_orchestrator.tracking import (
    CLIStateSnapshot,
    SessionState,
    WaitingState,
)


class TestInteractionType:
    """Check the wire value of every InteractionType member."""

    def test_interaction_types(self):
        """Test interaction type values."""
        # (member, expected serialized value), checked in declaration order.
        expectations = (
            (InteractionType.APPROVAL, "approval"),
            (InteractionType.CLARIFICATION, "clarification"),
            (InteractionType.TOOL_AUTHORIZATION, "tool_authorization"),
            (InteractionType.FILE_PERMISSION, "file_permission"),
            (InteractionType.COMMAND_CONFIRMATION, "command_confirmation"),
            (InteractionType.CONTINUE_PROMPT, "continue_prompt"),
            (InteractionType.CUSTOM_INPUT, "custom_input"),
            (InteractionType.UNKNOWN, "unknown"),
        )
        for member, expected_value in expectations:
            assert member.value == expected_value


class TestRiskLevel:
    """Check the wire value of every RiskLevel member."""

    def test_risk_levels(self):
        """Test risk level values."""
        # Ordered from least to most severe.
        expectations = (
            (RiskLevel.LOW, "low"),
            (RiskLevel.MEDIUM, "medium"),
            (RiskLevel.HIGH, "high"),
            (RiskLevel.CRITICAL, "critical"),
        )
        for level, expected_value in expectations:
            assert level.value == expected_value


class TestDetectedInteraction:
    """Construction and serialization checks for DetectedInteraction."""

    @staticmethod
    def _low_risk_approval(**extra_fields):
        """Return a low-risk APPROVAL interaction for agent ``claude-1``.

        Any ``extra_fields`` are passed straight through to the constructor.
        """
        return DetectedInteraction(
            agent_id="claude-1",
            agent_type="claude",
            interaction_type=InteractionType.APPROVAL,
            prompt_text="Continue? [y/n]",
            risk_level=RiskLevel.LOW,
            **extra_fields,
        )

    def test_create_interaction(self):
        """Test creating a detected interaction."""
        created = self._low_risk_approval(suggested_response="y")

        assert created.agent_id == "claude-1"
        assert created.interaction_type == InteractionType.APPROVAL
        assert created.risk_level == RiskLevel.LOW

    def test_to_dict(self):
        """Test to_dict serialization."""
        serialized = self._low_risk_approval().to_dict()

        # Enum fields are expected to serialize to their string values.
        assert serialized["agent_id"] == "claude-1"
        assert serialized["interaction_type"] == "approval"
        assert serialized["risk_level"] == "low"


class TestInteractionDetector:
    """Tests for InteractionDetector."""

    def setup_method(self):
        """Reset detector before each test."""
        set_interaction_detector(None)

    def teardown_method(self):
        """Reset the global detector after each test.

        Without this, the singleton created by test_global_detector_instance
        would outlive this class and leak into unrelated tests.
        """
        set_interaction_detector(None)

    @staticmethod
    def _make_reader(**session_fields):
        """Build a mock state reader whose snapshot wraps a single session.

        The session always has id "test", is active, and uses the same
        timestamp for start and last activity. ``session_fields`` are
        forwarded to ``SessionState`` unchanged. Only the fields a test sets
        are passed, so SessionState's own defaults apply to the rest.
        """
        now = datetime.now()
        session = SessionState(
            session_id="test",
            started_at=now,
            last_activity_at=now,
            is_active=True,
            **session_fields,
        )
        snapshot = CLIStateSnapshot(
            agent_type="claude",
            agent_id="claude-1",
            is_installed=True,
            state_dir_exists=True,
            session=session,
        )
        reader = MagicMock()
        reader.get_snapshot.return_value = snapshot
        return reader

    def test_detector_initialization(self):
        """Test detector initialization."""
        detector = InteractionDetector()
        assert detector._running is False

    def test_classify_approval_pattern(self):
        """Test classification of approval patterns."""
        detector = InteractionDetector()

        # Test various approval patterns
        assert detector._classify_interaction("Proceed? [y/n]") == InteractionType.APPROVAL
        assert detector._classify_interaction("Continue? [Y/N]") == InteractionType.APPROVAL
        assert detector._classify_interaction("Do you want to proceed? [y/n]") == InteractionType.APPROVAL
        assert detector._classify_interaction("Are you sure?") == InteractionType.APPROVAL

    def test_classify_tool_authorization(self):
        """Test classification of tool authorization patterns."""
        detector = InteractionDetector()

        assert detector._classify_interaction("Allow bash command?") == InteractionType.TOOL_AUTHORIZATION
        assert detector._classify_interaction("Run command: ls -la?") == InteractionType.TOOL_AUTHORIZATION
        assert detector._classify_interaction("Tool use: git status") == InteractionType.TOOL_AUTHORIZATION

    def test_classify_file_permission(self):
        """Test classification of file permission patterns."""
        detector = InteractionDetector()

        assert detector._classify_interaction("Create file test.py?") == InteractionType.FILE_PERMISSION
        assert detector._classify_interaction("Modify file config.json?") == InteractionType.FILE_PERMISSION
        assert detector._classify_interaction("Overwrite existing.txt?") == InteractionType.FILE_PERMISSION

    def test_classify_continue_prompt(self):
        """Test classification of continue prompts."""
        detector = InteractionDetector()

        assert detector._classify_interaction("Press enter to continue") == InteractionType.CONTINUE_PROMPT
        assert detector._classify_interaction("Press return to proceed") == InteractionType.CONTINUE_PROMPT

    def test_classify_unknown(self):
        """Test classification of unknown prompts."""
        detector = InteractionDetector()

        assert detector._classify_interaction("Some random text") == InteractionType.UNKNOWN

    def test_assess_risk_dangerous_patterns(self):
        """Test risk assessment for dangerous patterns."""
        detector = InteractionDetector()

        assert detector._assess_risk("rm -rf /", InteractionType.COMMAND_CONFIRMATION) == RiskLevel.CRITICAL
        assert detector._assess_risk("delete all files", InteractionType.COMMAND_CONFIRMATION) == RiskLevel.CRITICAL
        assert detector._assess_risk("DROP DATABASE users", InteractionType.COMMAND_CONFIRMATION) == RiskLevel.CRITICAL
        assert detector._assess_risk("git push --force", InteractionType.COMMAND_CONFIRMATION) == RiskLevel.CRITICAL
        assert detector._assess_risk("sudo rm", InteractionType.COMMAND_CONFIRMATION) == RiskLevel.CRITICAL
        assert detector._assess_risk("deploy to production", InteractionType.COMMAND_CONFIRMATION) == RiskLevel.CRITICAL

    def test_assess_risk_by_interaction_type(self):
        """Test risk assessment by interaction type."""
        detector = InteractionDetector()

        assert detector._assess_risk("Press enter", InteractionType.CONTINUE_PROMPT) == RiskLevel.LOW
        assert detector._assess_risk("Confirm?", InteractionType.APPROVAL) == RiskLevel.MEDIUM
        assert detector._assess_risk("Allow tool?", InteractionType.TOOL_AUTHORIZATION) == RiskLevel.HIGH
        assert detector._assess_risk("What?", InteractionType.UNKNOWN) == RiskLevel.HIGH

    def test_suggest_response_continue(self):
        """Test response suggestion for continue prompts."""
        detector = InteractionDetector()

        response = detector._suggest_response("Press enter", InteractionType.CONTINUE_PROMPT)
        assert response == ""

    def test_suggest_response_approval(self):
        """Test response suggestion for safe approval."""
        detector = InteractionDetector()

        response = detector._suggest_response("Continue? [y/n]", InteractionType.APPROVAL)
        assert response == "y"

    def test_no_suggestion_for_dangerous(self):
        """Test no suggestion for dangerous commands."""
        detector = InteractionDetector()

        response = detector._suggest_response("rm -rf /? [y/n]", InteractionType.APPROVAL)
        assert response is None

    def test_check_agent_not_waiting(self):
        """Test check_agent when agent is not waiting."""
        reader = self._make_reader(waiting_for_input=False)

        detector = InteractionDetector(readers={"claude": reader})
        interaction = detector.check_agent("claude")
        assert interaction is None

    def test_check_agent_waiting(self):
        """Test check_agent when agent is waiting."""
        reader = self._make_reader(
            waiting_for_input=True,
            waiting_state=WaitingState.APPROVAL,
            input_prompt="Continue? [y/n]",
        )

        detector = InteractionDetector(readers={"claude": reader})
        interaction = detector.check_agent("claude")

        assert interaction is not None
        # The expected agent_id is the id passed to check_agent ("claude"),
        # not the snapshot's agent_id ("claude-1").
        assert interaction.agent_id == "claude"
        assert interaction.interaction_type == InteractionType.APPROVAL

    def test_check_all_agents(self):
        """Test check_all_agents."""
        reader = self._make_reader(
            waiting_for_input=True,
            waiting_state=WaitingState.CLARIFICATION,
            input_prompt="Enter value:",
        )

        detector = InteractionDetector(readers={"claude": reader})
        interactions = detector.check_all_agents()

        assert len(interactions) == 1

    def test_callback_registration(self):
        """Test callback registration."""

        def callback(interaction):
            # Registration only: this test never triggers the callback.
            pass

        detector = InteractionDetector()
        detector.on_interaction(callback)
        assert len(detector._callbacks) == 1
        # Verify the callback we passed is the one stored, not just that
        # something was stored.
        assert callback in detector._callbacks

    def test_global_detector_instance(self):
        """Test global detector instance."""
        detector1 = get_interaction_detector()
        detector2 = get_interaction_detector()
        assert detector1 is detector2


class TestResponseAction:
    """Check the wire value of every ResponseAction member."""

    def test_response_actions(self):
        """Test response action values."""
        expectations = (
            (ResponseAction.AUTO_RESPOND, "auto_respond"),
            (ResponseAction.ESCALATE, "escalate"),
            (ResponseAction.DEFER, "defer"),
            (ResponseAction.REJECT, "reject"),
            (ResponseAction.SKIP, "skip"),
        )
        for action, expected_value in expectations:
            assert action.value == expected_value


class TestEscalationLevel:
    """Check the wire value of every EscalationLevel member."""

    def test_escalation_levels(self):
        """Test escalation level values."""
        expectations = (
            (EscalationLevel.NONE, "none"),
            (EscalationLevel.NOTIFY, "notify"),
            (EscalationLevel.REQUIRE_APPROVAL, "require_approval"),
            (EscalationLevel.BLOCK, "block"),
        )
        for level, expected_value in expectations:
            assert level.value == expected_value


class TestResponsePolicy:
    """Tests for ResponsePolicy dataclass."""

    def test_default_policy(self):
        """Test default policy values."""
        policy = ResponsePolicy()
        assert policy.auto_respond_max_risk == RiskLevel.LOW
        assert policy.escalate_min_risk == RiskLevel.HIGH
        assert policy.max_auto_responses_per_hour == 100

    def test_custom_policy(self):
        """Test custom policy configuration."""
        policy = ResponsePolicy(
            auto_respond_max_risk=RiskLevel.MEDIUM,
            always_auto_respond=[InteractionType.CONTINUE_PROMPT],
            never_auto_respond=[InteractionType.COMMAND_CONFIRMATION],
        )
        assert policy.auto_respond_max_risk == RiskLevel.MEDIUM
        assert InteractionType.CONTINUE_PROMPT in policy.always_auto_respond
        # never_auto_respond is configured above as well; previously it was
        # never checked, so a policy that dropped it would still pass.
        assert InteractionType.COMMAND_CONFIRMATION in policy.never_auto_respond

    def test_to_dict(self):
        """Test to_dict serialization."""
        policy = ResponsePolicy()
        data = policy.to_dict()
        assert data["auto_respond_max_risk"] == "low"
        assert data["max_auto_responses_per_hour"] == 100


class TestAutoResponseHandler:
    """Tests for AutoResponseHandler."""

    def setup_method(self):
        """Reset handler before each test."""
        set_auto_handler(None)

    def teardown_method(self):
        """Reset the global handler after each test.

        Without this, the singleton created by test_global_handler_instance
        would outlive this class and leak into unrelated tests.
        """
        set_auto_handler(None)

    @staticmethod
    def _continue_prompt(prompt_text="Press enter", agent_id="claude-1"):
        """Build a low-risk CONTINUE_PROMPT interaction with an empty suggested response."""
        return DetectedInteraction(
            agent_id=agent_id,
            agent_type="claude",
            interaction_type=InteractionType.CONTINUE_PROMPT,
            prompt_text=prompt_text,
            risk_level=RiskLevel.LOW,
            suggested_response="",
        )

    @staticmethod
    def _command_confirmation(
        prompt_text="Run command?", risk_level=RiskLevel.HIGH, agent_id="claude-1"
    ):
        """Build a COMMAND_CONFIRMATION interaction with no suggested response."""
        return DetectedInteraction(
            agent_id=agent_id,
            agent_type="claude",
            interaction_type=InteractionType.COMMAND_CONFIRMATION,
            prompt_text=prompt_text,
            risk_level=risk_level,
        )

    def test_handler_initialization(self):
        """Test handler initialization."""
        handler = AutoResponseHandler()
        assert handler._policy is not None
        assert handler._auto_response_count == 0

    def test_set_policy(self):
        """Test setting custom policy."""
        handler = AutoResponseHandler()
        policy = ResponsePolicy(auto_respond_max_risk=RiskLevel.MEDIUM)
        handler.set_policy(policy)
        assert handler._policy.auto_respond_max_risk == RiskLevel.MEDIUM

    def test_handle_low_risk_auto_response(self):
        """Test handling low risk interaction."""
        handler = AutoResponseHandler()
        interaction = self._continue_prompt(prompt_text="Press enter to continue")

        decision = handler.handle(interaction)
        assert decision.action == ResponseAction.AUTO_RESPOND
        assert decision.escalation_level == EscalationLevel.NONE

    def test_handle_high_risk_escalation(self):
        """Test handling high risk interaction."""
        handler = AutoResponseHandler()
        interaction = self._command_confirmation()

        decision = handler.handle(interaction)
        assert decision.action == ResponseAction.ESCALATE
        assert decision.escalation_level == EscalationLevel.REQUIRE_APPROVAL

    def test_handle_critical_risk_block(self):
        """Test handling critical risk interaction."""
        handler = AutoResponseHandler()
        interaction = self._command_confirmation(
            prompt_text="rm -rf /?", risk_level=RiskLevel.CRITICAL
        )

        decision = handler.handle(interaction)
        assert decision.action == ResponseAction.REJECT
        assert decision.escalation_level == EscalationLevel.BLOCK

    def test_always_escalate_policy(self):
        """Test always_escalate policy."""
        policy = ResponsePolicy(
            always_escalate=[InteractionType.FILE_PERMISSION],
        )
        handler = AutoResponseHandler(policy=policy)
        # Low risk with a suggestion: only the policy forces escalation.
        interaction = DetectedInteraction(
            agent_id="claude-1",
            agent_type="claude",
            interaction_type=InteractionType.FILE_PERMISSION,
            prompt_text="Create file?",
            risk_level=RiskLevel.LOW,
            suggested_response="y",
        )

        decision = handler.handle(interaction)
        assert decision.action == ResponseAction.ESCALATE

    def test_never_auto_respond_policy(self):
        """Test never_auto_respond policy."""
        policy = ResponsePolicy(
            never_auto_respond=[InteractionType.APPROVAL],
        )
        handler = AutoResponseHandler(policy=policy)
        # Low risk with a suggestion: only the policy prevents auto-response.
        interaction = DetectedInteraction(
            agent_id="claude-1",
            agent_type="claude",
            interaction_type=InteractionType.APPROVAL,
            prompt_text="Continue? [y/n]",
            risk_level=RiskLevel.LOW,
            suggested_response="y",
        )

        decision = handler.handle(interaction)
        assert decision.action == ResponseAction.ESCALATE

    def test_rate_limiting(self):
        """Test rate limiting of auto-responses."""
        policy = ResponsePolicy(max_auto_responses_per_hour=2)
        handler = AutoResponseHandler(policy=policy)
        interaction = self._continue_prompt()

        # The first two stay within the hourly budget and must be answered.
        # Asserting this ensures the DEFER below comes from the limit, not
        # from a handler that refuses everything.
        for _ in range(2):
            assert handler.handle(interaction).action == ResponseAction.AUTO_RESPOND

        # The third exceeds max_auto_responses_per_hour and must be deferred.
        decision = handler.handle(interaction)
        assert decision.action == ResponseAction.DEFER
        assert "rate limit" in decision.reason.lower()

    def test_get_history(self):
        """Test getting response history."""
        handler = AutoResponseHandler()

        handler.handle(self._continue_prompt())
        history = handler.get_history()
        assert len(history) == 1

    def test_get_history_filtered(self):
        """Test getting filtered history."""
        handler = AutoResponseHandler()

        interaction1 = self._continue_prompt(agent_id="claude-1")
        interaction2 = self._command_confirmation(agent_id="claude-2")

        handler.handle(interaction1)
        handler.handle(interaction2)

        history = handler.get_history(agent_id="claude-1")
        assert len(history) == 1
        assert history[0].interaction.agent_id == "claude-1"

    def test_get_stats(self):
        """Test getting handler statistics."""
        handler = AutoResponseHandler()

        handler.handle(self._continue_prompt())
        stats = handler.get_stats()
        assert stats["total_decisions"] == 1
        assert "by_action" in stats

    def test_callback_on_action(self):
        """Test action callbacks."""
        callback_calls = []

        def callback(decision):
            callback_calls.append(decision)

        handler = AutoResponseHandler()
        handler.on_action(ResponseAction.AUTO_RESPOND, callback)

        handler.handle(self._continue_prompt())
        assert len(callback_calls) == 1

    def test_escalation_callback(self):
        """Test escalation callbacks."""
        callback_calls = []

        def callback(decision):
            callback_calls.append(decision)

        handler = AutoResponseHandler()
        handler.on_escalation(callback)

        handler.handle(self._command_confirmation())
        assert len(callback_calls) == 1

    def test_global_handler_instance(self):
        """Test global handler instance."""
        handler1 = get_auto_handler()
        handler2 = get_auto_handler()
        assert handler1 is handler2


class TestResponseDecision:
    """Construction and serialization checks for ResponseDecision."""

    @staticmethod
    def _auto_approval_decision():
        """Return an AUTO_RESPOND decision answering "y" to a low-risk approval."""
        approval = DetectedInteraction(
            agent_id="claude-1",
            agent_type="claude",
            interaction_type=InteractionType.APPROVAL,
            prompt_text="Continue?",
            risk_level=RiskLevel.LOW,
        )
        return ResponseDecision(
            interaction=approval,
            action=ResponseAction.AUTO_RESPOND,
            response_text="y",
            escalation_level=EscalationLevel.NONE,
            reason="Low risk approval",
        )

    def test_create_decision(self):
        """Test creating response decision."""
        built = self._auto_approval_decision()

        assert built.action == ResponseAction.AUTO_RESPOND
        assert built.response_text == "y"

    def test_to_dict(self):
        """Test to_dict serialization."""
        as_dict = self._auto_approval_decision().to_dict()

        # The action enum is expected to serialize to its string value.
        assert as_dict["action"] == "auto_respond"
        assert as_dict["response_text"] == "y"


class TestInteractionRouter:
    """Tests for InteractionRouter."""

    @staticmethod
    def _reset_globals():
        """Clear the router, detector and auto-handler singletons."""
        set_interaction_router(None)
        # NOTE(review): a router built without explicit collaborators may
        # resolve the global detector/handler. Reset them so state left by
        # other tests (e.g. a restrictive policy) cannot leak in.
        set_interaction_detector(None)
        set_auto_handler(None)

    def setup_method(self):
        """Reset router (and the singletons it may use) before each test."""
        self._reset_globals()

    def teardown_method(self):
        """Reset globals so singletons created here do not outlive this class."""
        self._reset_globals()

    def test_router_initialization(self):
        """Test router initialization."""
        router = InteractionRouter()
        assert router._running is False

    def test_process_single_no_interaction(self):
        """Test process_single with no interaction."""
        detector = MagicMock()
        detector.check_agent.return_value = None

        router = InteractionRouter(detector=detector)
        decision = router.process_single("claude-1")
        assert decision is None

    def test_process_single_with_interaction(self):
        """Test process_single with interaction."""
        interaction = DetectedInteraction(
            agent_id="claude-1",
            agent_type="claude",
            interaction_type=InteractionType.CONTINUE_PROMPT,
            prompt_text="Press enter",
            risk_level=RiskLevel.LOW,
            suggested_response="",
        )
        detector = MagicMock()
        detector.check_agent.return_value = interaction

        router = InteractionRouter(detector=detector)
        decision = router.process_single("claude-1")

        assert decision is not None
        assert decision.action == ResponseAction.AUTO_RESPOND

    def test_on_auto_response_callback(self):
        """Test auto-response callback registration."""
        # Registration only: the callback is never invoked, so a mock is
        # enough (matches test_on_escalation_callback).
        callback = MagicMock()
        router = InteractionRouter()
        router.on_auto_response(callback)
        assert router._auto_response_callback is callback

    def test_on_escalation_callback(self):
        """Test escalation callback registration."""
        callback = MagicMock()
        router = InteractionRouter()
        router.on_escalation(callback)
        assert router._escalation_callback is callback

    def test_global_router_instance(self):
        """Test global router instance."""
        router1 = get_interaction_router()
        router2 = get_interaction_router()
        assert router1 is router2

    @pytest.mark.asyncio
    async def test_route_auto_response_decision(self):
        """Test routing auto-response decision."""
        callback_calls = []

        def callback(agent_id, response):
            callback_calls.append((agent_id, response))

        interaction = DetectedInteraction(
            agent_id="claude-1",
            agent_type="claude",
            interaction_type=InteractionType.CONTINUE_PROMPT,
            prompt_text="Press enter",
            risk_level=RiskLevel.LOW,
            suggested_response="",
        )
        decision = ResponseDecision(
            interaction=interaction,
            action=ResponseAction.AUTO_RESPOND,
            response_text="",
            escalation_level=EscalationLevel.NONE,
            reason="Test",
        )

        router = InteractionRouter()
        router.on_auto_response(callback)
        await router._route_decision(decision)

        # The auto-response callback receives (agent_id, response_text).
        assert len(callback_calls) == 1
        assert callback_calls[0] == ("claude-1", "")

    @pytest.mark.asyncio
    async def test_route_escalation_decision(self):
        """Test routing escalation decision."""
        callback_calls = []

        def callback(decision):
            callback_calls.append(decision)

        interaction = DetectedInteraction(
            agent_id="claude-1",
            agent_type="claude",
            interaction_type=InteractionType.COMMAND_CONFIRMATION,
            prompt_text="Run command?",
            risk_level=RiskLevel.HIGH,
        )
        decision = ResponseDecision(
            interaction=interaction,
            action=ResponseAction.ESCALATE,
            response_text=None,
            escalation_level=EscalationLevel.REQUIRE_APPROVAL,
            reason="Test",
        )

        router = InteractionRouter()
        router.on_escalation(callback)
        await router._route_decision(decision)

        assert len(callback_calls) == 1
