"""
Tests for Dashboard REST API.
"""

import pytest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

from fastapi.testclient import TestClient

from agent_orchestrator.api import create_dashboard_api, get_dashboard_app, set_dashboard_app
from agent_orchestrator.tracking import (
    CLIStateSnapshot,
    RateLimitState,
    SessionState,
    WindowType,
    WaitingState,
    set_rate_limit_monitor,
)
from agent_orchestrator.subscriptions import (
    SubscriptionTier,
    Provider,
    set_subscription_manager,
    SubscriptionManager,
)
from agent_orchestrator.interaction import (
    set_interaction_detector,
    set_auto_handler,
    InteractionType,
    RiskLevel,
    DetectedInteraction,
    ResponseAction,
    ResponseDecision,
    EscalationLevel,
)


@pytest.fixture
def client():
    """Return a ``TestClient`` wrapping a newly built dashboard API app."""
    return TestClient(create_dashboard_api())


def create_mock_snapshot(agent_type="claude", agent_id="claude-1", with_rate_limit=True):
    """Build a ``CLIStateSnapshot`` filled with fixed test values.

    Args:
        agent_type: Value stored as the snapshot's ``agent_type``.
        agent_id: Value stored as the snapshot's ``agent_id``.
        with_rate_limit: When true, attach a ``RateLimitState`` reporting
            50 of 100 requests used in a five-hour window that resets one
            hour from now; when false, ``rate_limit`` is ``None``.

    Returns:
        A snapshot for an installed agent whose state directory exists and
        whose session is active, not waiting for input, and idle.
    """
    rate_limit = (
        RateLimitState(
            requests_used=50,
            requests_limit=100,
            reset_at=datetime.now() + timedelta(hours=1),
            window_type=WindowType.FIVE_HOUR,
        )
        if with_rate_limit
        else None
    )

    active_idle_session = SessionState(
        session_id="test-session",
        started_at=datetime.now(),
        last_activity_at=datetime.now(),
        is_active=True,
        waiting_for_input=False,
        waiting_state=WaitingState.IDLE,
    )

    return CLIStateSnapshot(
        agent_type=agent_type,
        agent_id=agent_id,
        is_installed=True,
        state_dir_exists=True,
        rate_limit=rate_limit,
        session=active_idle_session,
    )


class TestHealthEndpoint:
    """Checks for the ``/api/health`` endpoint."""

    def test_health_check(self, client):
        """The health endpoint answers 200 with a healthy status and a timestamp."""
        resp = client.get("/api/health")
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["status"] == "healthy"
        assert "timestamp" in payload


class TestTiersEndpoint:
    """Checks for the ``/api/tiers`` endpoint."""

    def test_list_tiers(self, client):
        """Every expected provider is listed, and Anthropic offers ``claude_pro``."""
        resp = client.get("/api/tiers")
        assert resp.status_code == 200

        tiers_by_provider = resp.json()
        for provider in ("anthropic", "openai", "google"):
            assert provider in tiers_by_provider
        assert "claude_pro" in tiers_by_provider["anthropic"]


class TestAgentsEndpoints:
    """Checks for the ``/api/agents`` endpoints."""

    # Stacked @patch decorators inject mocks bottom-up: the innermost patch
    # (get_subscription_manager) is the first mock argument.
    @patch("agent_orchestrator.api.dashboard.get_all_readers")
    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_list_agents(self, mock_sub_mgr, mock_readers, client):
        """Listing agents with one known reader returns a JSON array."""
        fake_reader = MagicMock()
        fake_reader.get_snapshot.return_value = create_mock_snapshot()
        mock_readers.return_value = {"claude": fake_reader}

        # No agent has a registered subscription.
        subscriptions = MagicMock()
        subscriptions.get_subscription.return_value = None
        mock_sub_mgr.return_value = subscriptions

        resp = client.get("/api/agents")
        assert resp.status_code == 200
        assert isinstance(resp.json(), list)

    @patch("agent_orchestrator.api.dashboard.get_all_readers")
    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_get_agent_not_found(self, mock_sub_mgr, mock_readers, client):
        """Requesting an agent that no reader knows about yields 404."""
        mock_readers.return_value = {}

        resp = client.get("/api/agents/unknown")
        assert resp.status_code == 404


class TestRateLimitEndpoints:
    """Checks for the ``/api/rate-limits`` endpoints."""

    @staticmethod
    def _single_claude_reader():
        """Return a readers mapping with one ``claude`` reader serving a mock snapshot."""
        reader = MagicMock()
        reader.get_snapshot.return_value = create_mock_snapshot()
        return {"claude": reader}

    @patch("agent_orchestrator.api.dashboard.get_all_readers")
    def test_list_rate_limits(self, mock_readers, client):
        """Listing rate limits returns an array whose entries report usage."""
        mock_readers.return_value = self._single_claude_reader()

        resp = client.get("/api/rate-limits")
        assert resp.status_code == 200

        entries = resp.json()
        assert isinstance(entries, list)
        if entries:
            assert "current_usage" in entries[0]

    @patch("agent_orchestrator.api.dashboard.get_all_readers")
    def test_get_rate_limit(self, mock_readers, client):
        """A single agent's rate limit mirrors the snapshot's 50/100 usage."""
        mock_readers.return_value = self._single_claude_reader()

        resp = client.get("/api/rate-limits/claude")
        assert resp.status_code == 200

        body = resp.json()
        assert body["current_usage"] == 50
        assert body["limit"] == 100

    @patch("agent_orchestrator.api.dashboard.get_all_readers")
    def test_get_rate_limit_not_found(self, mock_readers, client):
        """Requesting the rate limit of an unknown agent yields 404."""
        mock_readers.return_value = {}

        resp = client.get("/api/rate-limits/unknown")
        assert resp.status_code == 404

    @patch("agent_orchestrator.api.dashboard.get_rate_limit_monitor")
    def test_get_rate_limit_alerts(self, mock_monitor, client):
        """With no recent alerts the alerts endpoint returns an array."""
        mock_monitor.return_value.get_recent_alerts.return_value = []

        resp = client.get("/api/rate-limits/alerts")
        assert resp.status_code == 200
        assert isinstance(resp.json(), list)


class TestSubscriptionEndpoints:
    """Checks for the ``/api/subscriptions`` endpoints."""

    @staticmethod
    def _manager_with_claude_pro():
        """Return a real manager with ``claude-1`` registered on ``CLAUDE_PRO``."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)
        return manager

    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_list_subscriptions(self, mock_mgr, client):
        """With no registered agents the listing is an array."""
        mock_mgr.return_value.list_agents.return_value = []

        resp = client.get("/api/subscriptions")
        assert resp.status_code == 200
        assert isinstance(resp.json(), list)

    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_get_subscription_not_found(self, mock_mgr, client):
        """Looking up an unregistered agent yields 404."""
        mock_mgr.return_value.get_subscription.return_value = None

        resp = client.get("/api/subscriptions/unknown")
        assert resp.status_code == 404

    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_register_subscription(self, mock_mgr, client):
        """Registering a valid tier echoes back the agent id and tier."""
        mock_mgr.return_value = SubscriptionManager()

        payload = {
            "agent_id": "claude-1",
            "tier": "claude_pro",
            "account_email": "test@example.com",
        }
        resp = client.post("/api/subscriptions", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert body["agent_id"] == "claude-1"
        assert body["tier"] == "claude_pro"

    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_register_subscription_invalid_tier(self, mock_mgr, client):
        """Registering an unknown tier name is rejected with 400."""
        mock_mgr.return_value = SubscriptionManager()

        payload = {
            "agent_id": "claude-1",
            "tier": "invalid_tier",
        }
        resp = client.post("/api/subscriptions", json=payload)
        assert resp.status_code == 400

    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_update_tier(self, mock_mgr, client):
        """Changing an existing subscription's tier returns the new tier."""
        mock_mgr.return_value = self._manager_with_claude_pro()

        resp = client.put("/api/subscriptions/claude-1/tier", json={"tier": "claude_max"})
        assert resp.status_code == 200
        assert resp.json()["tier"] == "claude_max"

    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_delete_subscription(self, mock_mgr, client):
        """Deleting a registered subscription succeeds."""
        mock_mgr.return_value = self._manager_with_claude_pro()

        resp = client.delete("/api/subscriptions/claude-1")
        assert resp.status_code == 200

    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_delete_subscription_not_found(self, mock_mgr, client):
        """Deleting a subscription that was never registered yields 404."""
        mock_mgr.return_value = SubscriptionManager()

        resp = client.delete("/api/subscriptions/unknown")
        assert resp.status_code == 404

    @patch("agent_orchestrator.api.dashboard.get_tier_detector")
    def test_detect_tier(self, mock_detector, client):
        """The detect-tier endpoint reports the detector's tier and confidence."""
        mock_detector.return_value.detect_from_cli_state.return_value = (
            SubscriptionTier.CLAUDE_PRO,
            0.85,
        )

        resp = client.post("/api/subscriptions/claude-1/detect-tier")
        assert resp.status_code == 200

        body = resp.json()
        assert body["detected_tier"] == "claude_pro"
        assert body["confidence"] == 0.85


class TestInteractionEndpoints:
    """Checks for the ``/api/interactions`` endpoints."""

    @patch("agent_orchestrator.api.dashboard.get_interaction_detector")
    def test_list_interactions(self, mock_detector, client):
        """With nothing pending the listing is an array."""
        mock_detector.return_value.check_all_agents.return_value = []

        resp = client.get("/api/interactions")
        assert resp.status_code == 200
        assert isinstance(resp.json(), list)

    @patch("agent_orchestrator.api.dashboard.get_interaction_detector")
    def test_get_interaction_not_found(self, mock_detector, client):
        """An agent with no pending interaction yields 404."""
        mock_detector.return_value.check_agent.return_value = None

        resp = client.get("/api/interactions/unknown")
        assert resp.status_code == 404

    @patch("agent_orchestrator.api.dashboard.get_interaction_detector")
    def test_get_interaction(self, mock_detector, client):
        """A pending approval prompt is returned with its agent and type."""
        pending = DetectedInteraction(
            agent_id="claude-1",
            agent_type="claude",
            interaction_type=InteractionType.APPROVAL,
            prompt_text="Continue? [y/n]",
            risk_level=RiskLevel.LOW,
        )
        mock_detector.return_value.check_agent.return_value = pending

        resp = client.get("/api/interactions/claude-1")
        assert resp.status_code == 200

        body = resp.json()
        assert body["agent_id"] == "claude-1"
        assert body["interaction_type"] == "approval"

    @patch("agent_orchestrator.api.dashboard.get_auto_handler")
    def test_get_interaction_history(self, mock_handler, client):
        """With an empty history the endpoint returns an array."""
        mock_handler.return_value.get_history.return_value = []

        resp = client.get("/api/interactions/history")
        assert resp.status_code == 200
        assert isinstance(resp.json(), list)

    @patch("agent_orchestrator.api.dashboard.get_auto_handler")
    def test_get_interaction_history_invalid_action(self, mock_handler, client):
        """Filtering history by an unknown action name is rejected with 400."""
        mock_handler.return_value = MagicMock()

        resp = client.get("/api/interactions/history?action=invalid")
        assert resp.status_code == 400


class TestStatsEndpoints:
    """Checks for the ``/api/stats`` endpoints."""

    # Mock arguments arrive bottom-up: the innermost @patch is the first mock.
    @patch("agent_orchestrator.api.dashboard.get_all_readers")
    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    @patch("agent_orchestrator.api.dashboard.get_interaction_detector")
    @patch("agent_orchestrator.api.dashboard.get_auto_handler")
    def test_get_stats(self, mock_handler, mock_detector, mock_sub_mgr, mock_readers, client):
        """Overall stats include total and active agent counts."""
        mock_readers.return_value = {}

        sub_mgr = mock_sub_mgr.return_value
        sub_mgr.get_summary.return_value = {
            "by_provider": {},
            "by_tier": {},
        }

        detector = mock_detector.return_value
        detector.check_all_agents.return_value = []

        handler = mock_handler.return_value
        handler.get_stats.return_value = {
            "auto_responses_this_hour": 0,
        }

        resp = client.get("/api/stats")
        assert resp.status_code == 200

        body = resp.json()
        for key in ("total_agents", "active_agents"):
            assert key in body

    @patch("agent_orchestrator.api.dashboard.get_rate_limit_monitor")
    def test_get_rate_limit_stats(self, mock_monitor, client):
        """The rate-limit stats endpoint answers 200 for an empty health summary."""
        monitor = mock_monitor.return_value
        monitor.get_health_summary.return_value = {
            "agents": {},
            "alerts_count": 0,
        }

        assert client.get("/api/stats/rate-limits").status_code == 200

    @patch("agent_orchestrator.api.dashboard.get_auto_handler")
    def test_get_interaction_stats(self, mock_handler, client):
        """The interaction stats endpoint answers 200 when no decisions exist."""
        handler = mock_handler.return_value
        handler.get_stats.return_value = {
            "total_decisions": 0,
            "by_action": {},
        }

        assert client.get("/api/stats/interactions").status_code == 200

    @patch("agent_orchestrator.api.dashboard.get_subscription_manager")
    def test_get_subscription_stats(self, mock_mgr, client):
        """The subscription stats endpoint answers 200 with no agents."""
        manager = mock_mgr.return_value
        manager.get_summary.return_value = {
            "total_agents": 0,
            "by_tier": {},
        }

        assert client.get("/api/stats/subscriptions").status_code == 200


class TestGlobalAppInstance:
    """Tests for global app instance management.

    ``set_dashboard_app`` mutates module-level state shared by the whole test
    session. Every test here therefore runs between two resets to ``None``,
    the value after which ``get_dashboard_app`` hands out a fresh, cached
    instance. Without the teardown reset, the custom app installed by
    ``test_set_dashboard_app`` would leak into any later test that calls
    ``get_dashboard_app``.
    """

    @pytest.fixture(autouse=True)
    def _isolate_global_app(self):
        """Reset the global dashboard app before and after each test."""
        set_dashboard_app(None)
        yield
        set_dashboard_app(None)

    def test_get_dashboard_app(self):
        """With no app installed, the getter returns one app and reuses it."""
        app1 = get_dashboard_app()
        app2 = get_dashboard_app()
        # Guard against the vacuous pass where the getter returns None twice.
        assert app1 is not None
        assert app1 is app2

    def test_set_dashboard_app(self):
        """An app installed with the setter is returned by the getter."""
        custom_app = create_dashboard_api()
        set_dashboard_app(custom_app)
        assert get_dashboard_app() is custom_app


class TestCORSMiddleware:
    """Tests for CORS middleware."""

    def test_cors_headers(self, client):
        """A CORS preflight is handled, and a successful one carries CORS headers.

        A 200 preflight means the origin was accepted, so the response must
        include ``Access-Control-Allow-Origin``. A 400 means the CORS
        middleware rejected the origin or method. Which outcome occurs
        depends on the app's CORS configuration.
        """
        response = client.options("/api/health", headers={
            "Origin": "http://localhost:3000",
            "Access-Control-Request-Method": "GET",
        })
        assert response.status_code in [200, 400]  # Depends on CORS config
        if response.status_code == 200:
            # Header lookup on the test client's response is case-insensitive.
            assert "access-control-allow-origin" in response.headers
