"""
Unit tests for the Configurable Risk Policy.

Tests:
- Configuration-driven risk patterns
- Custom pattern addition
- Level overrides
- Allowed/blocked path handling
"""

import json
import pytest
import tempfile
from pathlib import Path

from agent_orchestrator.risk import (
    RiskLevel,
    RiskClassification,
)
from agent_orchestrator.risk.configurable_policy import (
    ConfigurableRiskPolicy,
    RiskPolicyConfig,
    RiskPattern,
    get_configurable_policy,
    set_configurable_policy,
    load_policy_from_file,
)


class TestRiskPattern:
    """Exercise construction and (de)serialization of ``RiskPattern``."""

    def test_pattern_creation(self):
        """A pattern built with only the required fields gets the documented defaults."""
        rm_pattern = RiskPattern(
            pattern=r"rm -rf",
            description="Recursive delete",
            level=RiskLevel.CRITICAL,
        )

        assert rm_pattern.pattern == r"rm -rf"
        assert rm_pattern.level == RiskLevel.CRITICAL
        # Fields not passed explicitly fall back to their defaults.
        assert rm_pattern.enabled is True
        assert rm_pattern.source == "default"

    def test_pattern_to_dict(self):
        """``to_dict`` emits the level as its string form alongside the other fields."""
        serialized = RiskPattern(
            pattern=r"git push",
            description="Push to remote",
            level=RiskLevel.HIGH,
            source="custom",
            category="git",
        ).to_dict()

        expected_entries = {
            "pattern": r"git push",
            "level": "high",
            "source": "custom",
            "category": "git",
        }
        for key, expected_value in expected_entries.items():
            assert serialized[key] == expected_value

    def test_pattern_from_dict(self):
        """``from_dict`` parses the level string back into a ``RiskLevel``."""
        parsed = RiskPattern.from_dict(
            {
                "pattern": r"terraform destroy",
                "description": "Terraform destroy",
                "level": "critical",
                "enabled": True,
                "source": "organization",
                "category": "infrastructure",
            }
        )

        assert parsed.pattern == r"terraform destroy"
        assert parsed.level == RiskLevel.CRITICAL
        assert parsed.source == "organization"

    def test_pattern_disabled(self):
        """An explicit ``enabled=False`` is preserved on the instance."""
        disabled = RiskPattern(
            pattern=r"test",
            description="Test",
            level=RiskLevel.LOW,
            enabled=False,
        )

        assert disabled.enabled is False


class TestRiskPolicyConfig:
    """Exercise default values and (de)serialization of ``RiskPolicyConfig``."""

    def test_config_defaults(self):
        """An argument-free config includes defaults at MEDIUM with no custom patterns."""
        cfg = RiskPolicyConfig()

        assert cfg.include_defaults is True
        for level_attr in ("default_file_level", "default_command_level"):
            assert getattr(cfg, level_attr) == RiskLevel.MEDIUM
        assert len(cfg.command_patterns) == 0
        assert len(cfg.file_patterns) == 0

    def test_config_to_dict(self):
        """``to_dict`` carries the organization and path lists through."""
        serialized = RiskPolicyConfig(
            organization="test-org",
            allowed_paths=["/safe/.*"],
            blocked_paths=["/prod/.*"],
        ).to_dict()

        assert serialized["organization"] == "test-org"
        assert "/safe/.*" in serialized["allowed_paths"]
        assert "/prod/.*" in serialized["blocked_paths"]

    def test_config_from_dict(self):
        """``from_dict`` parses flags, organization, and level strings."""
        cfg = RiskPolicyConfig.from_dict(
            {
                "include_defaults": False,
                "organization": "acme-corp",
                "default_file_level": "high",
                "default_command_level": "high",
                "allowed_paths": ["/tmp/.*"],
                "blocked_paths": ["/secrets/.*"],
            }
        )

        assert cfg.include_defaults is False
        assert cfg.organization == "acme-corp"
        assert cfg.default_file_level == RiskLevel.HIGH


class TestConfigurableRiskPolicy:
    """Tests for the ConfigurableRiskPolicy class.

    File-based tests use pytest's ``tmp_path`` fixture rather than
    ``tempfile.NamedTemporaryFile(delete=False)``. This keeps files from
    accumulating in the system temp directory, and it avoids writing to a
    path whose handle is still open (which fails on Windows).
    """

    def test_policy_creation(self):
        """Test creating a policy with defaults."""
        policy = ConfigurableRiskPolicy()
        assert policy.config.include_defaults is True

    def test_policy_without_defaults(self):
        """Test policy without default patterns."""
        config = RiskPolicyConfig(include_defaults=False)
        policy = ConfigurableRiskPolicy(config)

        # Should fall back to defaults for unknown files/commands
        result = policy.classify_command("some random command")
        assert result.level == RiskLevel.MEDIUM
        assert "default" in result.reason.lower()

    def test_classify_critical_command(self):
        """Test classifying a critical command."""
        policy = ConfigurableRiskPolicy()

        result = policy.classify_command("rm -rf /")
        assert result.level == RiskLevel.CRITICAL

    def test_classify_high_risk_command(self):
        """Test classifying a high-risk command."""
        policy = ConfigurableRiskPolicy()

        result = policy.classify_command("git push origin main")
        # Note: Classification depends on default patterns, so this only
        # rules out CRITICAL.
        assert result.level in [RiskLevel.HIGH, RiskLevel.MEDIUM, RiskLevel.LOW]

    def test_classify_low_risk_command(self):
        """Test classifying a low-risk command."""
        policy = ConfigurableRiskPolicy()

        result = policy.classify_command("ls -la")
        assert result.level == RiskLevel.LOW

    def test_classify_critical_file(self):
        """Test classifying a critical file."""
        policy = ConfigurableRiskPolicy()

        result = policy.classify_file("/etc/passwd")
        assert result.level == RiskLevel.CRITICAL

    def test_classify_low_risk_file(self):
        """Test classifying a low-risk file."""
        policy = ConfigurableRiskPolicy()

        result = policy.classify_file("test_example.py")
        assert result.level == RiskLevel.LOW

    def test_add_command_pattern(self):
        """Test adding a custom command pattern."""
        policy = ConfigurableRiskPolicy()

        policy.add_command_pattern(
            pattern=r"deploy-to-production",
            description="Production deployment",
            level=RiskLevel.CRITICAL,
            category="deployment",
        )

        result = policy.classify_command("deploy-to-production --env=prod")
        assert result.level == RiskLevel.CRITICAL
        assert result.pattern_matched == r"deploy-to-production"

    def test_add_file_pattern(self):
        """Test adding a custom file pattern."""
        policy = ConfigurableRiskPolicy()

        policy.add_file_pattern(
            pattern=r"internal/secrets/.*",
            description="Internal secrets",
            level=RiskLevel.CRITICAL,
            category="security",
        )

        result = policy.classify_file("internal/secrets/api_keys.json")
        assert result.level == RiskLevel.CRITICAL

    def test_override_level(self):
        """Test overriding pattern level."""
        config = RiskPolicyConfig(include_defaults=True)
        policy = ConfigurableRiskPolicy(config)

        # Find a pattern and override it
        policy.override_level(r"git push", RiskLevel.CRITICAL)

        # Classification should now use the override
        result = policy.classify_command("git push origin main")
        # NOTE(review): this only proves classification still succeeds after
        # an override. Once it is confirmed that r"git push" is a default
        # pattern key, tighten to `result.level == RiskLevel.CRITICAL`.
        assert result is not None

    def test_allowed_paths(self):
        """Test allowed paths always return LOW risk."""
        config = RiskPolicyConfig(
            allowed_paths=[r"/safe-workspace/.*"],
        )
        policy = ConfigurableRiskPolicy(config)

        result = policy.classify_file("/safe-workspace/dangerous_script.sh")
        assert result.level == RiskLevel.LOW
        assert "allowed" in result.reason.lower()

    def test_blocked_paths(self):
        """Test blocked paths always return CRITICAL risk."""
        config = RiskPolicyConfig(
            blocked_paths=[r"/production/.*", r".*\.prod\..*"],
        )
        policy = ConfigurableRiskPolicy(config)

        result = policy.classify_file("/production/database.sql")
        assert result.level == RiskLevel.CRITICAL
        assert "blocked" in result.reason.lower()

        result = policy.classify_file("config.prod.yaml")
        assert result.level == RiskLevel.CRITICAL

    def test_add_allowed_path(self):
        """Test adding an allowed path at runtime."""
        policy = ConfigurableRiskPolicy()

        policy.add_allowed_path(r"/dev-workspace/.*")

        result = policy.classify_file("/dev-workspace/test.py")
        assert result.level == RiskLevel.LOW

    def test_add_blocked_path(self):
        """Test adding a blocked path at runtime."""
        policy = ConfigurableRiskPolicy()

        policy.add_blocked_path(r"/forbidden/.*")

        result = policy.classify_file("/forbidden/secret.txt")
        assert result.level == RiskLevel.CRITICAL

    def test_get_patterns_by_level(self):
        """Test getting patterns by risk level."""
        policy = ConfigurableRiskPolicy()

        critical_patterns = policy.get_patterns_by_level(RiskLevel.CRITICAL)
        assert len(critical_patterns) > 0

        # Filter by type
        critical_commands = policy.get_patterns_by_level(RiskLevel.CRITICAL, "command")
        critical_files = policy.get_patterns_by_level(RiskLevel.CRITICAL, "file")

        # Each filtered view should be no larger than the unfiltered one
        assert len(critical_commands) <= len(critical_patterns)
        assert len(critical_files) <= len(critical_patterns)

    def test_get_pattern_count(self):
        """Test getting pattern counts."""
        policy = ConfigurableRiskPolicy()

        counts = policy.get_pattern_count()
        assert "critical_commands" in counts
        assert "critical_files" in counts
        assert "allowed_paths" in counts
        assert "blocked_paths" in counts

    def test_load_from_file(self, tmp_path):
        """Test loading policy from JSON file.

        Args:
            tmp_path: pytest-provided per-test directory, removed by pytest.
        """
        config_data = {
            "include_defaults": True,
            "organization": "test-org",
            "default_file_level": "medium",
            "default_command_level": "high",
            "allowed_paths": ["/tmp/test/.*"],
            "blocked_paths": ["/critical/.*"],
            "command_patterns": [
                {
                    "pattern": r"custom-command",
                    "description": "Custom command",
                    "level": "high",
                    "enabled": True,
                    "category": "custom",
                }
            ],
        }

        config_file = tmp_path / "policy.json"
        config_file.write_text(json.dumps(config_data), encoding="utf-8")

        # Pass a str, matching how the original test called from_file.
        policy = ConfigurableRiskPolicy.from_file(str(config_file))

        assert policy.config.organization == "test-org"
        assert len(policy.config.allowed_paths) == 1

    def test_load_from_nonexistent_file(self, tmp_path):
        """Test loading from non-existent file returns defaults.

        Args:
            tmp_path: pytest-provided per-test directory, used to build a
                path that is guaranteed not to exist.
        """
        missing_file = tmp_path / "does-not-exist" / "path.json"
        assert not missing_file.exists()

        policy = ConfigurableRiskPolicy.from_file(str(missing_file))
        # Should return default policy
        assert policy.config.include_defaults is True

    def test_save_to_file(self, tmp_path):
        """Test saving policy to JSON file.

        Args:
            tmp_path: pytest-provided per-test directory, removed by pytest.
        """
        config = RiskPolicyConfig(
            organization="save-test",
            allowed_paths=["/safe/.*"],
        )
        policy = ConfigurableRiskPolicy(config)

        output_file = tmp_path / "saved_policy.json"
        # No other handle is open on this path, so the write works on every OS.
        policy.save_to_file(str(output_file))

        saved = json.loads(output_file.read_text(encoding="utf-8"))

        assert saved["organization"] == "save-test"
        assert "/safe/.*" in saved["allowed_paths"]

    def test_create_template_config(self):
        """Test creating template configuration."""
        template = ConfigurableRiskPolicy.create_template_config()

        assert "include_defaults" in template
        assert "organization" in template
        assert "allowed_paths" in template
        assert "blocked_paths" in template
        assert "command_patterns" in template
        assert "file_patterns" in template
        assert "level_overrides" in template


class TestModuleLevelFunctions:
    """Tests for module-level policy functions.

    These tests mutate the process-wide policy installed via
    ``set_configurable_policy``. The global is reset before and after every
    test through pytest's xunit-style hooks. Because the reset happens there
    rather than at the end of each test body, it still runs when an
    assertion fails, so no test can leak a custom policy into another.
    """

    def setup_method(self, method):
        """Start each test with no global policy installed."""
        set_configurable_policy(None)

    def teardown_method(self, method):
        """Clear the global policy, even if the test body raised."""
        set_configurable_policy(None)

    def test_get_configurable_policy_singleton(self):
        """Test getting the global policy instance."""
        policy1 = get_configurable_policy()
        policy2 = get_configurable_policy()

        assert policy1 is policy2

    def test_set_configurable_policy(self):
        """Test setting the global policy."""
        custom_config = RiskPolicyConfig(organization="custom-org")
        custom_policy = ConfigurableRiskPolicy(custom_config)

        set_configurable_policy(custom_policy)
        retrieved = get_configurable_policy()

        assert retrieved.config.organization == "custom-org"

    def test_load_policy_from_file(self, tmp_path):
        """Test loading and setting global policy from file.

        Args:
            tmp_path: pytest-provided per-test directory, removed by pytest.
        """
        config_data = {
            "organization": "file-loaded-org",
            "include_defaults": True,
        }

        config_file = tmp_path / "global_policy.json"
        config_file.write_text(json.dumps(config_data), encoding="utf-8")

        policy = load_policy_from_file(str(config_file))

        assert policy.config.organization == "file-loaded-org"

        # Should also set as global
        global_policy = get_configurable_policy()
        assert global_policy.config.organization == "file-loaded-org"


class TestConfigurableRiskPolicyIntegration:
    """End-to-end scenarios combining paths, custom patterns, and runtime edits."""

    def test_organization_policy_scenario(self):
        """An org policy mixing allowed/blocked paths and custom commands classifies as configured."""
        org_commands = [
            RiskPattern(
                pattern=r"acme-deploy",
                description="ACME deployment tool",
                level=RiskLevel.HIGH,
                source="organization",
                category="deployment",
            ),
            RiskPattern(
                pattern=r"acme-migrate",
                description="ACME database migration",
                level=RiskLevel.CRITICAL,
                source="organization",
                category="database",
            ),
        ]
        dev_paths = [
            r"/home/dev/workspace/.*",
            r"/tmp/builds/.*",
        ]
        prod_paths = [
            r"/production/.*",
            r".*\.prod\..*",
            r".*/prod-.*",
        ]
        org_policy = ConfigurableRiskPolicy(
            RiskPolicyConfig(
                organization="acme-corp",
                include_defaults=True,
                allowed_paths=dev_paths,
                blocked_paths=prod_paths,
                command_patterns=org_commands,
            )
        )

        # Dev workspace is allowed; production is blocked.
        for file_path, expected_level in (
            ("/home/dev/workspace/app.py", RiskLevel.LOW),
            ("/production/database.sql", RiskLevel.CRITICAL),
        ):
            assert org_policy.classify_file(file_path).level == expected_level

        # Organization-specific tools pick up their configured levels.
        for command, expected_level in (
            ("acme-deploy --env=staging", RiskLevel.HIGH),
            ("acme-migrate --force", RiskLevel.CRITICAL),
        ):
            assert org_policy.classify_command(command).level == expected_level

    def test_pattern_priority(self):
        """A blocked path wins over an allowed path that also matches."""
        priority_policy = ConfigurableRiskPolicy(
            RiskPolicyConfig(
                allowed_paths=[r"/workspace/.*"],
                blocked_paths=[r"/workspace/secrets/.*"],
            )
        )

        # Ordinary workspace file: only the allow rule applies.
        ordinary = priority_policy.classify_file("/workspace/app.py")
        assert ordinary.level == RiskLevel.LOW

        # Secrets under the workspace: both rules match, block takes precedence.
        secret = priority_policy.classify_file("/workspace/secrets/api_key.txt")
        assert secret.level == RiskLevel.CRITICAL

    def test_runtime_pattern_modification(self):
        """Patterns and allowed paths added after construction take effect immediately."""
        runtime_policy = ConfigurableRiskPolicy()

        runtime_policy.add_command_pattern(
            pattern=r"dangerous-tool",
            description="Dangerous tool",
            level=RiskLevel.CRITICAL,
        )
        tool_result = runtime_policy.classify_command("dangerous-tool --force")
        assert tool_result.level == RiskLevel.CRITICAL

        runtime_policy.add_allowed_path(r"/safe-zone/.*")
        zone_result = runtime_policy.classify_file("/safe-zone/dangerous.sh")
        assert zone_result.level == RiskLevel.LOW
