"""
Tests for the budget module.

Tests cover:
- Agent budget enforcement
- MCP server/tool budget tracking
- MCP registry allowlist/blocklist
- Token Audit integration
"""

import pytest
from datetime import date, datetime, timedelta
from unittest.mock import MagicMock, AsyncMock

from agent_orchestrator.budget.agent_budget import (
    AgentBudget,
    DailyUsage,
    BudgetStatus,
    BudgetCheckResult,
    BudgetEnforcer,
    DEFAULT_BUDGETS,
    TOKEN_COSTS,
)

from agent_orchestrator.budget.mcp_budget import (
    MCPServerBudget,
    ToolBudget,
    MCPUsageRecord,
    MCPBudgetStatus,
    MCPBudgetCheckResult,
    MCPUsageTracker,
    DEFAULT_MCP_BUDGETS,
    HIGH_RISK_TOOLS,
)

from agent_orchestrator.budget.registry import (
    MCPRegistry,
    AccessLevel,
    RegistryDecision,
    RegistryCheckResult,
    ServerRegistration,
    ToolRegistration,
    DEFAULT_ALLOWED_SERVERS,
    DEFAULT_BLOCKED_SERVERS,
)


# =============================================================================
# Agent Budget Tests
# =============================================================================

class TestAgentBudget:
    """Unit tests for the ``AgentBudget`` dataclass and the module-level tables."""

    def test_default_budget_values(self):
        """A budget constructed from only an agent id carries the default limits."""
        default = AgentBudget(agent_id="test-agent")

        assert default.daily_input_tokens == 1_000_000
        assert default.daily_output_tokens == 200_000
        assert default.daily_cost_limit == 50.0
        assert default.warning_threshold == 0.8
        assert default.is_paused is False

    def test_custom_budget_values(self):
        """Explicit constructor arguments take precedence over the defaults."""
        overrides = dict(
            daily_input_tokens=500_000,
            daily_output_tokens=100_000,
            daily_cost_limit=25.0,
            max_tokens_per_request=50_000,
        )
        custom = AgentBudget(agent_id="custom-agent", **overrides)

        assert custom.daily_input_tokens == 500_000
        assert custom.daily_cost_limit == 25.0
        assert custom.max_tokens_per_request == 50_000

    def test_default_budgets_exist(self):
        """Every known CLI agent has an entry in ``DEFAULT_BUDGETS``."""
        for known_agent in ("claude-code", "gemini-cli", "codex-cli"):
            assert known_agent in DEFAULT_BUDGETS

    def test_token_costs_defined(self):
        """``TOKEN_COSTS`` has a claude-code entry and a usable fallback entry."""
        assert "claude-code" in TOKEN_COSTS
        assert "default" in TOKEN_COSTS

        fallback_costs = TOKEN_COSTS["default"]
        assert "input" in fallback_costs
        assert "output" in fallback_costs


class TestBudgetEnforcer:
    """Behavioural tests for ``BudgetEnforcer`` backed by a mocked database."""

    @pytest.fixture
    def mock_db(self):
        """Return a MagicMock DB whose ``get_daily_usage`` reports no stored usage."""
        fake_db = MagicMock()
        fake_db.get_daily_usage.return_value = None
        return fake_db

    @pytest.fixture
    def enforcer(self, mock_db):
        """Return a ``BudgetEnforcer`` wired to the mocked database."""
        return BudgetEnforcer(db=mock_db)

    def test_get_budget_returns_default(self, enforcer):
        """An agent with no configured budget receives a default budget."""
        fallback = enforcer.get_budget("new-agent")

        assert fallback.agent_id == "new-agent"
        # 1M input tokens is the AgentBudget default.
        assert fallback.daily_input_tokens == 1_000_000

    def test_get_budget_returns_known_agent(self, enforcer):
        """A known agent receives its own, non-default budget."""
        claude = enforcer.get_budget("claude-code")

        assert claude.agent_id == "claude-code"
        # claude-code is configured with twice the default input allowance.
        assert claude.daily_input_tokens == 2_000_000

    def test_set_budget(self, enforcer):
        """A budget stored with ``set_budget`` is what ``get_budget`` returns."""
        enforcer.set_budget(
            "custom",
            AgentBudget(agent_id="custom", daily_input_tokens=500_000),
        )

        assert enforcer.get_budget("custom").daily_input_tokens == 500_000

    def test_check_budget_within_limits(self, enforcer):
        """A small request against a fresh budget is allowed without warning."""
        verdict = enforcer.check_budget(
            "claude-code",
            estimated_input_tokens=10_000,
            estimated_output_tokens=2_000,
        )

        assert verdict.status == BudgetStatus.WITHIN_BUDGET
        assert verdict.allowed is True

    def test_check_budget_exceeds_daily_limit(self, enforcer):
        """A request that would push input usage past the daily cap is denied."""
        agent = "test-agent"
        # Shrink the daily input allowance. This assumes get_budget hands back
        # the enforcer's stored (mutable) budget object.
        enforcer.get_budget(agent).daily_input_tokens = 100_000

        # Consume 99% of the allowance...
        enforcer.record_usage(agent, input_tokens=99_000, output_tokens=0)

        # ...so that another 2k input tokens would overshoot it.
        verdict = enforcer.check_budget(
            agent,
            estimated_input_tokens=2_000,
            estimated_output_tokens=0,
        )

        assert verdict.status == BudgetStatus.EXCEEDED
        assert verdict.allowed is False
        assert "daily input token limit" in verdict.reason

    def test_check_budget_exceeds_per_request_limit(self, enforcer):
        """A single request larger than the per-request cap is denied."""
        enforcer.get_budget("test-agent").max_tokens_per_request = 50_000

        verdict = enforcer.check_budget(
            "test-agent",
            estimated_input_tokens=60_000,
            estimated_output_tokens=0,
        )

        assert verdict.status == BudgetStatus.EXCEEDED
        assert verdict.allowed is False
        assert "max tokens per request" in verdict.reason

    def test_check_budget_warning_threshold(self, enforcer):
        """Usage past the warning threshold but under the cap yields WARNING."""
        agent = "test-agent"
        limits = enforcer.get_budget(agent)
        limits.daily_input_tokens = 100_000
        limits.warning_threshold = 0.8

        # 85k of 100k is above the 80% warning line but below the hard cap.
        enforcer.record_usage(agent, input_tokens=85_000, output_tokens=0)

        verdict = enforcer.check_budget(agent)

        assert verdict.status == BudgetStatus.WARNING
        assert verdict.allowed is True

    def test_check_budget_paused_agent(self, enforcer):
        """A paused agent is refused even though no usage has been recorded."""
        enforcer.pause_agent("test-agent", "Testing pause")

        verdict = enforcer.check_budget("test-agent")

        assert verdict.status == BudgetStatus.PAUSED
        assert verdict.allowed is False

    def test_record_usage(self, enforcer):
        """A single recorded request shows up in the usage summary."""
        enforcer.record_usage("test-agent", input_tokens=10_000, output_tokens=2_000)

        usage = enforcer.get_usage_summary("test-agent")["usage"]

        assert usage["input_tokens"] == 10_000
        assert usage["output_tokens"] == 2_000
        assert usage["request_count"] == 1

    def test_record_usage_accumulates(self, enforcer):
        """Token counts and request counts add up across recorded requests."""
        for tokens_in, tokens_out in ((5_000, 1_000), (3_000, 500)):
            enforcer.record_usage(
                "test-agent", input_tokens=tokens_in, output_tokens=tokens_out
            )

        usage = enforcer.get_usage_summary("test-agent")["usage"]

        assert usage["input_tokens"] == 8_000
        assert usage["output_tokens"] == 1_500
        assert usage["request_count"] == 2

    def test_pause_resume_agent(self, enforcer):
        """Pausing sets the flag and reason; resuming clears both."""
        agent = "test-agent"
        enforcer.pause_agent(agent, "Pausing for test")

        paused = enforcer.get_budget(agent)
        assert paused.is_paused is True
        assert paused.pause_reason == "Pausing for test"

        enforcer.resume_agent(agent)

        resumed = enforcer.get_budget(agent)
        assert resumed.is_paused is False
        assert resumed.pause_reason is None

    def test_usage_summary_shows_percentages(self, enforcer):
        """The summary reports consumed fractions of each daily allowance."""
        limits = enforcer.get_budget("test-agent")
        limits.daily_input_tokens = 100_000
        limits.daily_output_tokens = 20_000

        # Use exactly half of each allowance.
        enforcer.record_usage("test-agent", input_tokens=50_000, output_tokens=10_000)

        ratios = enforcer.get_usage_summary("test-agent")["percentages"]

        assert ratios["input_tokens"] == 0.5
        assert ratios["output_tokens"] == 0.5


# =============================================================================
# MCP Budget Tests
# =============================================================================

class TestMCPServerBudget:
    """Unit tests for ``MCPServerBudget`` and the MCP module's default tables."""

    def test_default_server_budget(self):
        """A server budget built from only a name carries the default limits."""
        server_budget = MCPServerBudget(server_name="test-server")

        assert server_budget.daily_call_limit == 1000
        assert server_budget.hourly_call_limit == 100
        assert server_budget.is_blocked is False

    def test_default_mcp_budgets_exist(self):
        """The commonly used MCP servers have default budgets."""
        for server in ("filesystem", "github"):
            assert server in DEFAULT_MCP_BUDGETS

    def test_high_risk_tools_defined(self):
        """Destructive or externally visible tools are listed as high risk."""
        # (server, tool) pairs where the tool must appear in the server's set.
        expected_pairs = (
            ("filesystem", "delete_file"),
            ("github", "create_pull_request"),
        )
        for server, tool in expected_pairs:
            assert server in HIGH_RISK_TOOLS
            assert tool in HIGH_RISK_TOOLS[server]


class TestMCPUsageTracker:
    """Behavioural tests for ``MCPUsageTracker`` backed by a mocked database."""

    @pytest.fixture
    def mock_db(self):
        """Return a MagicMock DB whose ``get_mcp_usage`` reports no stored usage."""
        fake_db = MagicMock()
        fake_db.get_mcp_usage.return_value = None
        return fake_db

    @pytest.fixture
    def tracker(self, mock_db):
        """Return an ``MCPUsageTracker`` wired to the mocked database."""
        return MCPUsageTracker(db=mock_db)

    def test_check_budget_within_limits(self, tracker):
        """A first call to a default server/tool pair is allowed."""
        verdict = tracker.check_budget("filesystem", "read_file")

        assert verdict.status == MCPBudgetStatus.WITHIN_BUDGET
        assert verdict.allowed is True

    def test_check_budget_blocked_server(self, tracker):
        """Every tool on a blocked server is refused."""
        server = "test-server"
        tracker.block_server(server, "Testing block")

        verdict = tracker.check_budget(server, "some_tool")

        assert verdict.status == MCPBudgetStatus.BLOCKED
        assert verdict.allowed is False

    def test_check_budget_exceeds_daily_limit(self, tracker):
        """Once the daily call allowance is used up, further calls are refused."""
        server, tool, daily_cap = "test-server", "test_tool", 10
        tracker.get_server_budget(server).daily_call_limit = daily_cap

        # Use up the daily allowance exactly.
        for _ in range(daily_cap):
            tracker.record_call(server, tool)

        verdict = tracker.check_budget(server, tool)

        assert verdict.status == MCPBudgetStatus.EXCEEDED
        assert verdict.allowed is False
        assert "Daily call limit" in verdict.reason

    def test_check_budget_exceeds_hourly_limit(self, tracker):
        """Once the hourly call allowance is used up, further calls are refused."""
        server, tool, hourly_cap = "test-server", "test_tool", 5
        tracker.get_server_budget(server).hourly_call_limit = hourly_cap

        # Use up the hourly allowance exactly.
        for _ in range(hourly_cap):
            tracker.record_call(server, tool)

        verdict = tracker.check_budget(server, tool)

        assert verdict.status == MCPBudgetStatus.EXCEEDED
        assert verdict.allowed is False
        assert "Hourly call limit" in verdict.reason

    def test_record_call(self, tracker):
        """Recorded calls appear in both the total count and per-tool breakdown."""
        for tool in ("read_file", "write_file"):
            tracker.record_call("filesystem", tool)

        summary = tracker.get_server_summary("filesystem")

        assert summary["usage"]["total_calls"] == 2
        per_tool = summary["tool_breakdown"]
        assert per_tool["read_file"] == 1
        assert per_tool["write_file"] == 1

    def test_requires_approval_high_risk(self, tracker):
        """High-risk tools need approval; ordinary tools do not."""
        expectations = {"delete_file": True, "read_file": False}
        for tool, needs_approval in expectations.items():
            assert tracker.requires_approval("filesystem", tool) is needs_approval

    def test_block_unblock_server(self, tracker):
        """Blocking and unblocking a server toggles its ``is_blocked`` flag."""
        server = "test-server"
        tracker.block_server(server, "Testing")

        assert tracker.get_server_budget(server).is_blocked is True

        tracker.unblock_server(server)

        assert tracker.get_server_budget(server).is_blocked is False

    def test_block_unblock_tool(self, tracker):
        """A blocked tool is refused until it is unblocked again."""
        server, tool = "filesystem", "dangerous_tool"
        tracker.block_tool(server, tool, "Too dangerous")

        assert tracker.check_budget(server, tool).status == MCPBudgetStatus.BLOCKED

        tracker.unblock_tool(server, tool)

        assert (
            tracker.check_budget(server, tool).status
            == MCPBudgetStatus.WITHIN_BUDGET
        )


# =============================================================================
# MCP Registry Tests
# =============================================================================

class TestMCPRegistry:
    """Allowlist, blocklist and approval behaviour of ``MCPRegistry``."""

    @pytest.fixture
    def registry(self):
        """Return a registry built with default constructor arguments."""
        return MCPRegistry()

    def test_default_allowed_servers(self, registry):
        """Common development servers are allowed out of the box."""
        for server in ("filesystem", "github", "git"):
            assert registry.is_allowed(server) is True

    def test_default_blocked_servers(self, registry):
        """Servers named after code-execution primitives are blocked out of the box."""
        for server in ("shell", "eval", "exec"):
            assert registry.is_blocked(server) is True

    def test_unknown_server_blocked_by_default(self):
        """With ``allow_unknown_servers=False`` an unregistered server is blocked."""
        strict = MCPRegistry(allow_unknown_servers=False)

        verdict = strict.check_access("unknown-server")

        assert verdict.decision == RegistryDecision.BLOCK
        assert verdict.allowed is False

    def test_unknown_server_allowed_when_configured(self):
        """With ``allow_unknown_servers=True`` an unregistered server is allowed."""
        permissive = MCPRegistry(allow_unknown_servers=True)

        verdict = permissive.check_access("unknown-server")

        assert verdict.decision == RegistryDecision.ALLOW
        assert verdict.allowed is True

    def test_check_access_allowed_server(self, registry):
        """A low-risk tool on an allowed server is granted access."""
        verdict = registry.check_access("filesystem", "read_file")

        assert verdict.decision == RegistryDecision.ALLOW
        assert verdict.allowed is True

    def test_check_access_blocked_server(self, registry):
        """Any tool on a blocked server is denied."""
        verdict = registry.check_access("shell", "execute")

        assert verdict.decision == RegistryDecision.BLOCK
        assert verdict.allowed is False

    def test_check_access_approval_required(self, registry):
        """A high-risk tool requires approval and is not allowed automatically."""
        verdict = registry.check_access("filesystem", "delete_file")

        assert verdict.decision == RegistryDecision.REQUIRE_APPROVAL
        assert verdict.allowed is False

    def test_allow_server(self, registry):
        """``allow_server`` puts a new server on the allowlist only."""
        registry.allow_server("new-server", "Adding for testing")

        assert registry.is_allowed("new-server") is True
        assert registry.is_blocked("new-server") is False

    def test_block_server(self, registry):
        """``block_server`` overrides a default-allowed server."""
        registry.block_server("filesystem", "Blocking for test")

        assert registry.is_blocked("filesystem") is True
        assert registry.check_access("filesystem").decision == RegistryDecision.BLOCK

    def test_register_server(self, registry):
        """A registered server honours its own per-tool blocklist."""
        registration = registry.register_server(
            server_name="custom-server",
            access_level=AccessLevel.ALLOWED,
            description="Custom MCP server",
            blocked_tools={"dangerous_tool"},
        )

        assert registration.server_name == "custom-server"
        assert "dangerous_tool" in registration.blocked_tools

        assert registry.check_access("custom-server", "safe_tool").allowed is True
        assert registry.check_access("custom-server", "dangerous_tool").allowed is False

    def test_register_tool(self, registry):
        """A tool registered as approval-required is reported as such."""
        registration = registry.register_tool(
            server_name="filesystem",
            tool_name="custom_tool",
            access_level=AccessLevel.APPROVAL_REQUIRED,
            requires_approval=True,
            risk_level="high",
        )

        assert registration.requires_approval is True

        verdict = registry.check_access("filesystem", "custom_tool")
        assert verdict.decision == RegistryDecision.REQUIRE_APPROVAL

    def test_require_approval_for_tool(self, registry):
        """``require_approval_for_tool`` flags a previously unknown tool."""
        registry.require_approval_for_tool("filesystem", "new_risky_tool")

        verdict = registry.check_access("filesystem", "new_risky_tool")
        assert verdict.decision == RegistryDecision.REQUIRE_APPROVAL

    def test_high_risk_pattern_detection(self, registry):
        """Tool names that look destructive require approval without registration."""
        # The server has to be allowed first, or access would be blocked outright.
        registry.allow_server("database")

        # Neither tool is registered; their names start with 'delete' / 'drop'.
        for risky_tool in ("delete_all_records", "drop_table"):
            verdict = registry.check_access("database", risky_tool)
            assert verdict.decision == RegistryDecision.REQUIRE_APPROVAL

    def test_list_servers(self, registry):
        """``list_servers`` returns records that include the default servers."""
        servers = registry.list_servers()

        assert len(servers) > 0
        assert "filesystem" in {entry["server_name"] for entry in servers}

    def test_export_import_config(self, registry):
        """Allow/block changes survive an export into a fresh registry."""
        registry.allow_server("custom-server")
        registry.block_server("test-block")

        snapshot = registry.export_config()

        restored = MCPRegistry()
        restored.import_config(snapshot)

        assert "custom-server" in restored.get_allowed_servers()
        assert "test-block" in restored.get_blocked_servers()


class TestRegistryWithAgentFiltering:
    """Tests for per-agent access control on individual tool registrations."""

    @pytest.fixture
    def registry(self):
        """Create a registry with default settings.

        Nothing agent-specific is configured here. Each test sets up per-agent
        allow/block rules on the tool registration it creates.
        """
        return MCPRegistry()

    def test_tool_blocked_for_specific_agent(self, registry):
        """A tool can be denied to one agent while staying available to others."""
        registration = registry.register_tool(
            server_name="github",
            tool_name="force_push",
            requires_approval=False,
        )
        # Mutates the returned registration in place. This assumes the registry
        # keeps a reference to this object rather than a copy.
        registration.blocked_agents.add("junior-agent")

        result = registry.check_access("github", "force_push", agent_id="junior-agent")
        assert result.allowed is False

        result = registry.check_access("github", "force_push", agent_id="senior-agent")
        assert result.allowed is True

    def test_tool_allowed_only_for_specific_agents(self, registry):
        """A tool with an agent allowlist is denied to agents not on the list."""
        # The server must be allowed first, or every tool on it would be blocked.
        registry.allow_server("database")

        registration = registry.register_tool(
            server_name="database",
            tool_name="admin_query",
        )
        # Same in-place mutation assumption as above.
        registration.allowed_agents = {"admin-agent", "dba-agent"}

        result = registry.check_access("database", "admin_query", agent_id="admin-agent")
        assert result.allowed is True

        result = registry.check_access("database", "admin_query", agent_id="regular-agent")
        assert result.allowed is False


# =============================================================================
# Integration Tests
# =============================================================================

class TestBudgetIntegration:
    """Integration tests combining the agent budget, MCP budget and registry."""

    @pytest.fixture
    def mock_db(self):
        """Create a mock database shared by the enforcer and the MCP tracker.

        Both lookups return ``None`` so every component starts with no stored usage.
        """
        db = MagicMock()
        db.get_daily_usage.return_value = None
        db.get_mcp_usage.return_value = None
        return db

    def test_agent_and_mcp_budget_together(self, mock_db):
        """Agent and MCP checks pass together, and both record usage."""
        agent_enforcer = BudgetEnforcer(db=mock_db)
        mcp_tracker = MCPUsageTracker(db=mock_db)
        registry = MCPRegistry()

        # Check agent budget
        agent_result = agent_enforcer.check_budget("claude-code")
        assert agent_result.allowed is True

        # Check MCP registry
        registry_result = registry.check_access("filesystem", "read_file")
        assert registry_result.allowed is True

        # Check MCP budget
        mcp_result = mcp_tracker.check_budget("filesystem", "read_file")
        assert mcp_result.allowed is True

        # Record both usages
        agent_enforcer.record_usage("claude-code", input_tokens=1000, output_tokens=100)
        mcp_tracker.record_call("filesystem", "read_file")

        # Verify that recording took effect in both trackers.
        agent_usage = agent_enforcer.get_usage_summary("claude-code")["usage"]
        assert agent_usage["input_tokens"] == 1000
        assert agent_usage["output_tokens"] == 100
        assert agent_usage["request_count"] == 1

        mcp_summary = mcp_tracker.get_server_summary("filesystem")
        assert mcp_summary["usage"]["total_calls"] == 1
        assert mcp_summary["tool_breakdown"]["read_file"] == 1

    def test_budget_enforcement_flow(self, mock_db):
        """Walk through the check -> approve -> execute -> record sequence."""
        enforcer = BudgetEnforcer(db=mock_db)
        tracker = MCPUsageTracker(db=mock_db)
        registry = MCPRegistry()

        # Simulate a request flow
        agent_id = "claude-code"
        server_name = "filesystem"
        tool_name = "write_file"

        # 1. Check agent budget
        budget_check = enforcer.check_budget(agent_id, estimated_input_tokens=5000)
        assert budget_check.allowed is True

        # 2. Check registry access
        access_check = registry.check_access(server_name, tool_name)
        # write_file requires approval by default
        assert access_check.decision == RegistryDecision.REQUIRE_APPROVAL

        # 3. After approval, check MCP budget
        mcp_check = tracker.check_budget(server_name, tool_name)
        assert mcp_check.allowed is True

        # 4. Execute and record
        enforcer.record_usage(agent_id, input_tokens=4800, output_tokens=200)
        tracker.record_call(server_name, tool_name)

        # 5. Verify recording
        agent_summary = enforcer.get_usage_summary(agent_id)
        assert agent_summary["usage"]["request_count"] == 1

        mcp_summary = tracker.get_server_summary(server_name)
        assert mcp_summary["usage"]["total_calls"] == 1


# =============================================================================
# CLI Session Persistence Tests
# =============================================================================

from agent_orchestrator.budget.cli_usage_tracker import (
    CLIUsageTracker,
    CLISession,
    RateLimitConfig,
    RateLimitState,
)
from agent_orchestrator.persistence.database import OrchestratorDB


class TestCLISessionPersistence:
    """Tests for CLI session lifecycle and rate-limit state persistence.

    Uses a real ``OrchestratorDB`` file in a temporary directory together with
    a ``MagicMock`` (spec'd on ``BudgetEnforcer``) in place of the enforcer.
    """

    @pytest.fixture
    def db(self, tmp_path):
        """Create a file-backed test database under ``tmp_path``."""
        db_path = tmp_path / "test.db"
        return OrchestratorDB(str(db_path))

    @pytest.fixture
    def budget_enforcer(self):
        """Create a mock budget enforcer whose ``check_budget`` always allows."""
        enforcer = MagicMock(spec=BudgetEnforcer)
        enforcer.check_budget.return_value = BudgetCheckResult(
            status=BudgetStatus.WITHIN_BUDGET,
            allowed=True,
        )
        enforcer.get_budget.return_value = AgentBudget(
            agent_id="test-agent",
            daily_input_tokens=100000,
            daily_output_tokens=20000,
            daily_cost_limit=10.0,
        )
        enforcer._get_daily_usage.return_value = DailyUsage(
            agent_id="test-agent",
            date=date.today(),
        )
        return enforcer

    @pytest.fixture
    def tracker(self, budget_enforcer, db):
        """Create a CLI usage tracker that does not restore persisted state."""
        return CLIUsageTracker(
            budget_enforcer=budget_enforcer,
            db=db,
            auto_restore_state=False,  # Don't auto-restore for tests
        )

    def test_start_session(self, tracker):
        """Test starting a CLI session."""
        session = tracker.start_session("claude-code")

        assert session.agent_id == "claude-code"
        assert session.is_active is True
        assert session.request_count == 0
        assert "session-claude-code" in session.session_id

    def test_end_session(self, tracker):
        """Test ending a CLI session."""
        tracker.start_session("claude-code")

        session = tracker.end_session("claude-code")

        assert session is not None
        assert session.is_active is False
        assert tracker.get_session("claude-code") is None

    def test_get_or_start_session(self, tracker):
        """Test getting or starting a session."""
        # Should create new session
        session1 = tracker.get_or_start_session("claude-code")
        assert session1.agent_id == "claude-code"

        # Should return same session
        session2 = tracker.get_or_start_session("claude-code")
        assert session1.session_id == session2.session_id

    def test_session_stats_updated(self, tracker):
        """Test that session stats are updated on usage."""
        import asyncio

        from agent_orchestrator.adapters.base import AgentResponse

        tracker.start_session("claude-code")
        response = AgentResponse(
            content="test",
            tokens_input=1000,
            tokens_output=200,
            cost=0.05,
            success=True,
        )

        # asyncio.run creates and closes its own event loop. Calling
        # asyncio.get_event_loop() when no loop is set is deprecated and
        # raises RuntimeError on Python 3.14+.
        asyncio.run(tracker.record_cli_usage("claude-code", response))

        session = tracker.get_session("claude-code")
        assert session.request_count == 1
        assert session.total_tokens == 1200
        # Cost is a float; compare with a tolerance instead of exact equality.
        assert session.total_cost_usd == pytest.approx(0.05)

    def test_end_all_sessions(self, tracker):
        """Test ending all sessions."""
        tracker.start_session("claude-code")
        tracker.start_session("gemini-cli")

        count = tracker.end_all_sessions()

        assert count == 2
        assert tracker.get_session("claude-code") is None
        assert tracker.get_session("gemini-cli") is None

    def test_save_and_restore_state(self, tracker, db):
        """Test saving and restoring rate limit state."""
        agent_id = "claude-code"

        # Create some state
        state = tracker._get_rate_state(agent_id)
        state.recent_requests.append(datetime.now())
        state.consecutive_limit_hits = 2

        # Start a session and save
        tracker.start_session(agent_id)
        tracker._save_state(agent_id)

        # Clear in-memory state so the restore must come from the database
        tracker._rate_states = {}

        # Restore
        restored = tracker._restore_state(agent_id)
        assert restored is True

        new_state = tracker._get_rate_state(agent_id)
        assert new_state.consecutive_limit_hits == 2

    def test_session_history(self, tracker, db):
        """Test getting session history."""
        # Create and end a session
        tracker.start_session("claude-code")
        tracker.end_session("claude-code")

        # Create another
        tracker.start_session("claude-code")
        tracker.end_session("claude-code")

        history = tracker.get_session_history("claude-code", limit=10)
        assert len(history) == 2


class TestDatabaseCLISessionOperations:
    """Direct tests of ``OrchestratorDB``'s CLI-session and rate-limit storage."""

    @pytest.fixture
    def db(self, tmp_path):
        """Return an ``OrchestratorDB`` backed by a fresh file in ``tmp_path``."""
        return OrchestratorDB(str(tmp_path / "test.db"))

    @staticmethod
    def _register_agent(db):
        """Insert the ``claude-code`` agent row before creating sessions for it."""
        from agent_orchestrator.persistence.models import Agent

        db.create_agent(Agent(id="claude-code", tool="claude_code"))

    def test_create_and_get_session(self, db):
        """A newly created session is returned as the agent's active session."""
        self._register_agent(db)

        new_id = db.generate_session_id("claude-code")
        db.create_cli_session(new_id, "claude-code")

        active = db.get_active_cli_session("claude-code")
        assert active is not None
        assert active["id"] == new_id
        assert active["agent_id"] == "claude-code"
        assert active["is_active"] == 1

    def test_end_session(self, db):
        """An ended session is no longer reported as active."""
        self._register_agent(db)

        new_id = db.generate_session_id("claude-code")
        db.create_cli_session(new_id, "claude-code")

        db.end_cli_session(
            session_id=new_id,
            request_count=10,
            total_tokens=5000,
            total_cost_usd=0.25,
        )

        assert db.get_active_cli_session("claude-code") is None

    def test_save_and_load_rate_limit_state(self, db):
        """Saved rate-limit state round-trips through ``load_rate_limit_state``."""
        import json

        stamp = "2024-01-01T12:00:00"
        db.save_rate_limit_state(
            agent_id="claude-code",
            session_id=None,
            recent_request_timestamps=json.dumps([stamp]),
            recent_token_usage=json.dumps([[stamp, 1000]]),
            cooldown_until=None,
            consecutive_limit_hits=3,
        )

        loaded = db.load_rate_limit_state("claude-code")
        assert loaded is not None
        assert loaded["consecutive_limit_hits"] == 3
        assert len(json.loads(loaded["recent_request_timestamps"])) == 1

    def test_session_history(self, db):
        """Every ended session is listed in the agent's history."""
        self._register_agent(db)

        # Open and immediately close three sessions.
        for index in range(3):
            sid = f"session-{index}"
            db.create_cli_session(sid, "claude-code")
            db.end_cli_session(sid, request_count=index + 1)

        assert len(db.get_cli_session_history("claude-code", limit=10)) == 3
