"""
MCP Registry - Allowlist and blocklist for MCP servers and tools.

Implements:
- Server allowlist/blocklist
- Tool-level access control
- Approval-required tools registry
- Dynamic registration and configuration
"""

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional, Dict, List, Set, Any, Callable
import re
import logging

logger = logging.getLogger(__name__)


class AccessLevel(Enum):
    """Access level assigned to an MCP server or tool registration.

    The string values are what MCPRegistry's listing methods emit under the
    ``access_level`` key (via ``.value``).
    """
    ALLOWED = "allowed"
    BLOCKED = "blocked"
    APPROVAL_REQUIRED = "approval_required"
    UNKNOWN = "unknown"  # Not in registry; also the default level on RegistryCheckResult


class RegistryDecision(Enum):
    """Decision returned by ``MCPRegistry.check_access``.

    ``ALLOW`` results carry ``allowed=True``; ``BLOCK`` and
    ``REQUIRE_APPROVAL`` results carry ``allowed=False``. How an approval is
    actually obtained is not handled by the code in this module.
    """
    ALLOW = "allow"
    BLOCK = "block"
    REQUIRE_APPROVAL = "require_approval"


@dataclass
class ServerRegistration:
    """Registration entry for an MCP server.

    Holds the server's access level, optional per-tool restrictions and audit
    metadata. Instances are created and mutated by MCPRegistry.
    """

    server_name: str
    access_level: AccessLevel = AccessLevel.ALLOWED

    # Server info (descriptive; not consulted by MCPRegistry.check_access())
    description: Optional[str] = None
    version: Optional[str] = None
    source_url: Optional[str] = None

    # Access control
    allowed_tools: Optional[Set[str]] = None  # None = all tools allowed; otherwise only these tools
    blocked_tools: Set[str] = field(default_factory=set)  # Tools refused for this server
    # NOTE: MCPRegistry.check_access() enforces approval from the registry's own
    # per-server approval map, not from this field; the registry may store the
    # same set object in both places.
    approval_required_tools: Set[str] = field(default_factory=set)

    # Trust settings (not read by MCPRegistry.check_access())
    trust_level: str = "standard"  # trusted, standard, untrusted
    requires_review: bool = False

    # Audit
    registered_at: datetime = field(default_factory=datetime.now)  # Naive local time of creation
    registered_by: str = "system"
    last_used_at: Optional[datetime] = None  # Refreshed by check_access() when a tool call is allowed

    # Notes
    notes: Optional[str] = None  # Free text, e.g. the reason given to MCPRegistry.block_server()


@dataclass
class ToolRegistration:
    """Registration entry for a specific tool on a specific MCP server.

    MCPRegistry stores these keyed by server and tool name and consults them
    in check_access().
    """

    server_name: str
    tool_name: str
    access_level: AccessLevel = AccessLevel.ALLOWED  # BLOCKED makes check_access() refuse the tool

    # Tool info
    description: Optional[str] = None
    risk_level: str = "low"  # low, medium, high, critical (informational; not read by check_access())

    # Access control
    requires_approval: bool = False  # True makes check_access() require approval unless the tool is blocked
    allowed_agents: Optional[Set[str]] = None  # None = all agents
    blocked_agents: Set[str] = field(default_factory=set)

    # Usage patterns
    usage_patterns: List[str] = field(default_factory=list)  # Expected use cases (informational)

    # Audit
    registered_at: datetime = field(default_factory=datetime.now)  # Naive local time of creation
    last_used_at: Optional[datetime] = None  # Refreshed by check_access() when the tool call is allowed


@dataclass
class RegistryCheckResult:
    """Result of a registry access check, as returned by MCPRegistry.check_access()."""

    decision: RegistryDecision
    allowed: bool  # True only when decision is RegistryDecision.ALLOW

    server_name: str
    tool_name: Optional[str] = None  # None for server-only checks
    agent_id: Optional[str] = None  # Requesting agent, when the caller supplied one

    # Details
    reason: str = ""  # Human-readable explanation; left empty on ALLOW results
    access_level: AccessLevel = AccessLevel.UNKNOWN

    # For approval required: set on approval-required tool checks, with the
    # keys "server", "tool" and "agent"
    approval_context: Optional[Dict[str, Any]] = None


# Default allowed servers (built-in, trusted).
# MCPRegistry starts from a copy of this set when its constructor is not
# given an allowlist.
DEFAULT_ALLOWED_SERVERS: Set[str] = {
    "filesystem",
    "github",
    "git",
    "memory",
    "sqlite",
    "postgres",
    "fetch",
    "browser-use",
    "sequential-thinking",
}

# Default blocked servers (known risky).
# MCPRegistry starts from a copy of this set when its constructor is not
# given a blocklist.
DEFAULT_BLOCKED_SERVERS: Set[str] = {
    "shell",  # Direct shell access - too dangerous
    "eval",   # Code evaluation
    "exec",   # Command execution
    "system", # System commands
}

# Tools that always require approval, keyed by server name.
# This is the default per-server approval map for new MCPRegistry instances.
# Treat it (and its inner sets) as read-only: it is the starting point for
# every registry created without an explicit map.
DEFAULT_APPROVAL_REQUIRED_TOOLS: Dict[str, Set[str]] = {
    "filesystem": {
        "delete_file",
        "delete_directory",
        "move_file",
        "write_file",  # Writing arbitrary files
    },
    "github": {
        "create_pull_request",
        "merge_pull_request",
        "delete_branch",
        "push_to_remote",
        "create_release",
    },
    "git": {
        "push",
        "force_push",
        "delete_branch",
        "reset_hard",
    },
    "postgres": {
        "execute_query",  # Could be destructive
        "drop_table",
        "truncate_table",
        "alter_table",
    },
    "sqlite": {
        "execute_query",
        "drop_table",
    },
}

# High-risk tool patterns (regex).
# MCPRegistry compiles these with re.IGNORECASE and tests them with re.match,
# so each pattern is anchored at the START of the tool name: "delete_file"
# matches r"delete.*" but "file_delete" does not. A match is one of the
# signals check_access() uses to require approval.
HIGH_RISK_PATTERNS: List[str] = [
    r"delete.*",
    r"drop.*",
    r"truncate.*",
    r"remove.*",
    r"destroy.*",
    r"exec.*",
    r"eval.*",
    r"run.*command.*",
    r"shell.*",
    r"force.*push.*",
]


class MCPRegistry:
    """Registry for MCP server and tool access control.

    Decisions come from :meth:`check_access`, which combines:

    * quick-lookup allow/block sets of server names,
    * detailed :class:`ServerRegistration` / :class:`ToolRegistration` entries
      (per-server tool allow/block lists, per-tool agent allow/block lists),
    * a per-server map of tools that always require approval, and
    * the name-based ``HIGH_RISK_PATTERNS``.

    Each instance owns private copies of all of its sets and maps. Changes
    made through one registry never leak into the module-level defaults, the
    caller's collections, or other registries.
    """

    def __init__(
        self,
        allowed_servers: Optional[Set[str]] = None,
        blocked_servers: Optional[Set[str]] = None,
        approval_required_tools: Optional[Dict[str, Set[str]]] = None,
        allow_unknown_servers: bool = False,
    ):
        """
        Initialize MCP registry.

        Args:
            allowed_servers: Allowed server names. ``None`` uses
                ``DEFAULT_ALLOWED_SERVERS``; an empty set allows none.
            blocked_servers: Blocked server names. ``None`` uses
                ``DEFAULT_BLOCKED_SERVERS``; an empty set blocks none.
            approval_required_tools: Map of server -> tools requiring approval.
                ``None`` uses ``DEFAULT_APPROVAL_REQUIRED_TOOLS``.
            allow_unknown_servers: Whether to allow servers not in registry

        All given collections are copied, inner sets included.
        """
        # ``is None`` rather than ``or``: an explicit empty collection is a
        # deliberate choice and must not silently turn into the defaults.
        self._allowed_servers: Set[str] = set(
            DEFAULT_ALLOWED_SERVERS if allowed_servers is None else allowed_servers
        )
        self._blocked_servers: Set[str] = set(
            DEFAULT_BLOCKED_SERVERS if blocked_servers is None else blocked_servers
        )
        approval_source = (
            DEFAULT_APPROVAL_REQUIRED_TOOLS
            if approval_required_tools is None
            else approval_required_tools
        )
        # Copy the inner sets as well: a shallow dict copy would let later
        # add() calls mutate the module-level defaults (or the caller's sets).
        self._approval_required_tools: Dict[str, Set[str]] = {
            server: set(tools) for server, tools in approval_source.items()
        }
        self._allow_unknown = allow_unknown_servers

        # Detailed registrations; tool registrations are keyed by _tool_key().
        self._server_registrations: Dict[str, ServerRegistration] = {}
        self._tool_registrations: Dict[str, ToolRegistration] = {}

        # Compiled patterns for high-risk detection
        self._high_risk_patterns = [re.compile(p, re.IGNORECASE) for p in HIGH_RISK_PATTERNS]

        # Callbacks for approval (kept for API compatibility; not invoked here)
        self._approval_callbacks: List[Callable] = []

        # Initialize default registrations
        self._init_defaults()

    @staticmethod
    def _tool_key(server_name: str, tool_name: str) -> str:
        """Build the lookup key used for ``_tool_registrations``."""
        return f"{server_name}:{tool_name}"

    def _init_defaults(self) -> None:
        """Create registrations for listed servers that have none yet."""
        for server in self._allowed_servers:
            if server not in self._server_registrations:
                self._server_registrations[server] = ServerRegistration(
                    server_name=server,
                    access_level=AccessLevel.ALLOWED,
                    # Shares the map's set (when present) so later additions
                    # are visible from the registration as well.
                    approval_required_tools=self._approval_required_tools.get(server, set()),
                )

        for server in self._blocked_servers:
            if server not in self._server_registrations:
                self._server_registrations[server] = ServerRegistration(
                    server_name=server,
                    access_level=AccessLevel.BLOCKED,
                    description="Blocked by default policy",
                )

    def register_server(
        self,
        server_name: str,
        access_level: AccessLevel = AccessLevel.ALLOWED,
        description: Optional[str] = None,
        allowed_tools: Optional[Set[str]] = None,
        blocked_tools: Optional[Set[str]] = None,
        approval_required_tools: Optional[Set[str]] = None,
        registered_by: str = "user",
    ) -> ServerRegistration:
        """
        Register (or replace the registration of) an MCP server.

        Args:
            server_name: Name of the MCP server
            access_level: Access level for the server. ALLOWED / BLOCKED also
                update the quick-lookup allow/block sets.
            description: Description of the server
            allowed_tools: Specific tools to allow (None = all)
            blocked_tools: Tools to block
            approval_required_tools: Tools requiring approval. When non-empty,
                they replace this server's entry in the approval map.
            registered_by: Who registered the server

        Returns:
            ServerRegistration object
        """
        # Copy caller collections so later registry mutations stay private.
        approval_tools: Set[str] = set(approval_required_tools) if approval_required_tools else set()
        registration = ServerRegistration(
            server_name=server_name,
            access_level=access_level,
            description=description,
            allowed_tools=set(allowed_tools) if allowed_tools is not None else None,
            blocked_tools=set(blocked_tools) if blocked_tools else set(),
            approval_required_tools=approval_tools,
            registered_by=registered_by,
        )

        self._server_registrations[server_name] = registration

        # Update quick lookup sets
        if access_level == AccessLevel.ALLOWED:
            self._allowed_servers.add(server_name)
            self._blocked_servers.discard(server_name)
        elif access_level == AccessLevel.BLOCKED:
            self._blocked_servers.add(server_name)
            self._allowed_servers.discard(server_name)

        if approval_tools:
            # Same set object as the registration's, so require_approval_for_tool()
            # keeps both views in sync.
            self._approval_required_tools[server_name] = approval_tools

        logger.info("Registered MCP server: %s (%s)", server_name, access_level.value)
        return registration

    def register_tool(
        self,
        server_name: str,
        tool_name: str,
        access_level: AccessLevel = AccessLevel.ALLOWED,
        requires_approval: bool = False,
        risk_level: str = "low",
        description: Optional[str] = None,
    ) -> ToolRegistration:
        """
        Register (or replace the registration of) a specific tool.

        Args:
            server_name: MCP server name
            tool_name: Tool name
            access_level: Access level for the tool. APPROVAL_REQUIRED implies
                ``requires_approval=True``.
            requires_approval: Whether tool requires approval
            risk_level: Risk level (low/medium/high/critical)
            description: Tool description

        Returns:
            ToolRegistration object
        """
        needs_approval = requires_approval or access_level == AccessLevel.APPROVAL_REQUIRED

        registration = ToolRegistration(
            server_name=server_name,
            tool_name=tool_name,
            access_level=access_level,
            requires_approval=needs_approval,
            risk_level=risk_level,
            description=description,
        )

        self._tool_registrations[self._tool_key(server_name, tool_name)] = registration

        if needs_approval:
            self._approval_required_tools.setdefault(server_name, set()).add(tool_name)

        logger.info("Registered tool: %s:%s (%s)", server_name, tool_name, access_level.value)
        return registration

    def allow_server(self, server_name: str, reason: Optional[str] = None) -> None:
        """Add server to allowlist, recording *reason* when given."""
        self._allowed_servers.add(server_name)
        self._blocked_servers.discard(server_name)

        registration = self._server_registrations.get(server_name)
        if registration is None:
            self.register_server(server_name, AccessLevel.ALLOWED, description=reason)
        else:
            registration.access_level = AccessLevel.ALLOWED
            if reason is not None:
                registration.notes = reason

        logger.info("Allowed MCP server: %s", server_name)

    def block_server(self, server_name: str, reason: Optional[str] = None) -> None:
        """Add server to blocklist, recording *reason* when given."""
        self._blocked_servers.add(server_name)
        self._allowed_servers.discard(server_name)

        registration = self._server_registrations.get(server_name)
        if registration is None:
            self.register_server(server_name, AccessLevel.BLOCKED, description=reason)
        else:
            registration.access_level = AccessLevel.BLOCKED
            if reason is not None:
                registration.notes = reason

        logger.info("Blocked MCP server: %s", server_name)

    def require_approval_for_tool(self, server_name: str, tool_name: str) -> None:
        """Mark a tool as requiring approval (registering it if needed)."""
        self._approval_required_tools.setdefault(server_name, set()).add(tool_name)

        tool_registration = self._tool_registrations.get(self._tool_key(server_name, tool_name))
        if tool_registration is None:
            self.register_tool(server_name, tool_name, requires_approval=True)
        else:
            tool_registration.requires_approval = True

    def _is_high_risk_tool(self, tool_name: str) -> bool:
        """Return True if *tool_name* matches (from its start) a high-risk pattern."""
        return any(pattern.match(tool_name) for pattern in self._high_risk_patterns)

    def _is_server_blocked(self, server_name: str) -> bool:
        """Return True if the server is blocklisted or has a BLOCKED registration."""
        if server_name in self._blocked_servers:
            return True
        registration = self._server_registrations.get(server_name)
        # Fail closed if a registration says BLOCKED even though the
        # quick-lookup set does not (e.g. the object was edited directly).
        return registration is not None and registration.access_level == AccessLevel.BLOCKED

    @staticmethod
    def _blocked_result(
        server_name: str,
        tool_name: Optional[str],
        agent_id: Optional[str],
        reason: str,
        access_level: AccessLevel = AccessLevel.BLOCKED,
    ) -> RegistryCheckResult:
        """Build a BLOCK result."""
        return RegistryCheckResult(
            decision=RegistryDecision.BLOCK,
            allowed=False,
            server_name=server_name,
            tool_name=tool_name,
            agent_id=agent_id,
            reason=reason,
            access_level=access_level,
        )

    @staticmethod
    def _approval_result(
        server_name: str,
        tool_name: Optional[str],
        agent_id: Optional[str],
        reason: str,
    ) -> RegistryCheckResult:
        """Build a REQUIRE_APPROVAL result including its approval context."""
        return RegistryCheckResult(
            decision=RegistryDecision.REQUIRE_APPROVAL,
            allowed=False,
            server_name=server_name,
            tool_name=tool_name,
            agent_id=agent_id,
            reason=reason,
            access_level=AccessLevel.APPROVAL_REQUIRED,
            approval_context={
                "server": server_name,
                "tool": tool_name,
                "agent": agent_id,
            },
        )

    @staticmethod
    def _tool_block_reason(
        server_name: str,
        tool_name: str,
        agent_id: Optional[str],
        registration: Optional[ServerRegistration],
        tool_registration: Optional[ToolRegistration],
    ) -> Optional[str]:
        """Return why the tool call must be blocked, or None if nothing blocks it."""
        if tool_registration is not None and tool_registration.access_level == AccessLevel.BLOCKED:
            return f"Tool '{tool_name}' is blocked"

        if registration is not None:
            if tool_name in registration.blocked_tools:
                return f"Tool '{tool_name}' is blocked for server '{server_name}'"
            if registration.allowed_tools is not None and tool_name not in registration.allowed_tools:
                return f"Tool '{tool_name}' not in allowed tools for '{server_name}'"

        if tool_registration is not None:
            if agent_id and agent_id in tool_registration.blocked_agents:
                return f"Agent '{agent_id}' blocked from tool '{tool_name}'"
            allowed_agents = tool_registration.allowed_agents
            if allowed_agents is not None:
                if not agent_id:
                    # An agent allowlist cannot be satisfied by an anonymous call.
                    return (
                        f"Tool '{tool_name}' is restricted to specific agents "
                        f"and no agent was given"
                    )
                if agent_id not in allowed_agents:
                    return f"Agent '{agent_id}' not allowed for tool '{tool_name}'"

        return None

    def _approval_reason(
        self,
        server_name: str,
        tool_name: str,
        registration: Optional[ServerRegistration],
        tool_registration: Optional[ToolRegistration],
    ) -> Optional[str]:
        """Return why the tool call needs approval, or None if it does not.

        Each trigger is checked independently. For example, a server having
        its own approval list does not exempt its other tools from the
        high-risk name check.
        """
        if tool_registration is not None and (
            tool_registration.requires_approval
            or tool_registration.access_level == AccessLevel.APPROVAL_REQUIRED
        ):
            return f"Tool '{tool_name}' requires approval"

        if tool_name in self._approval_required_tools.get(server_name, ()):
            return f"Tool '{tool_name}' requires approval for server '{server_name}'"

        if registration is not None and registration.access_level == AccessLevel.APPROVAL_REQUIRED:
            return f"Server '{server_name}' requires approval"

        if self._is_high_risk_tool(tool_name):
            return f"Tool '{tool_name}' matches high-risk pattern"

        return None

    def check_access(
        self,
        server_name: str,
        tool_name: Optional[str] = None,
        agent_id: Optional[str] = None,
    ) -> RegistryCheckResult:
        """
        Check if access is allowed for server/tool.

        Checks run in this order; the first one that fires decides:

        1. Server blocked (blocklist or BLOCKED registration) -> BLOCK.
        2. Server unknown (not allowlisted and no registration other than
           UNKNOWN) while unknown servers are disallowed -> BLOCK.
        3. No tool given: APPROVAL_REQUIRED server -> REQUIRE_APPROVAL,
           otherwise ALLOW.
        4. Tool-level blocks (tool/server lists, agent lists) -> BLOCK.
        5. Approval triggers (tool, server approval map, APPROVAL_REQUIRED
           server, high-risk name) -> REQUIRE_APPROVAL.
        6. Otherwise ALLOW, refreshing ``last_used_at`` on the registrations.

        Args:
            server_name: MCP server name
            tool_name: Optional tool name
            agent_id: Optional agent making the request

        Returns:
            RegistryCheckResult with decision
        """
        registration = self._server_registrations.get(server_name)

        if self._is_server_blocked(server_name):
            return self._blocked_result(
                server_name, tool_name, agent_id, f"Server '{server_name}' is blocked"
            )

        is_known = server_name in self._allowed_servers or (
            registration is not None and registration.access_level != AccessLevel.UNKNOWN
        )
        if not is_known and not self._allow_unknown:
            return self._blocked_result(
                server_name,
                tool_name,
                agent_id,
                f"Unknown server '{server_name}' not in allowlist",
                access_level=AccessLevel.UNKNOWN,
            )

        # Server-only check
        if tool_name is None:
            if registration is not None and registration.access_level == AccessLevel.APPROVAL_REQUIRED:
                return self._approval_result(
                    server_name, None, agent_id, f"Server '{server_name}' requires approval"
                )
            return RegistryCheckResult(
                decision=RegistryDecision.ALLOW,
                allowed=True,
                server_name=server_name,
                agent_id=agent_id,
                access_level=AccessLevel.ALLOWED,
            )

        tool_registration = self._tool_registrations.get(self._tool_key(server_name, tool_name))

        block_reason = self._tool_block_reason(
            server_name, tool_name, agent_id, registration, tool_registration
        )
        if block_reason is not None:
            return self._blocked_result(server_name, tool_name, agent_id, block_reason)

        approval_reason = self._approval_reason(
            server_name, tool_name, registration, tool_registration
        )
        if approval_reason is not None:
            return self._approval_result(server_name, tool_name, agent_id, approval_reason)

        # Update last used timestamp
        now = datetime.now()
        if registration is not None:
            registration.last_used_at = now
        if tool_registration is not None:
            tool_registration.last_used_at = now

        return RegistryCheckResult(
            decision=RegistryDecision.ALLOW,
            allowed=True,
            server_name=server_name,
            tool_name=tool_name,
            agent_id=agent_id,
            access_level=AccessLevel.ALLOWED,
        )

    def is_allowed(self, server_name: str, tool_name: Optional[str] = None) -> bool:
        """Quick check if server/tool is allowed (no agent context)."""
        return self.check_access(server_name, tool_name).allowed

    def is_blocked(self, server_name: str) -> bool:
        """Check if server is blocked (same rule as check_access step 1)."""
        return self._is_server_blocked(server_name)

    def requires_approval(self, server_name: str, tool_name: str) -> bool:
        """Check if tool requires approval."""
        result = self.check_access(server_name, tool_name)
        return result.decision == RegistryDecision.REQUIRE_APPROVAL

    def get_allowed_servers(self) -> Set[str]:
        """Get a copy of the set of allowed servers."""
        return self._allowed_servers.copy()

    def get_blocked_servers(self) -> Set[str]:
        """Get a copy of the set of blocked servers."""
        return self._blocked_servers.copy()

    def get_server_registration(self, server_name: str) -> Optional[ServerRegistration]:
        """Get registration for a server."""
        return self._server_registrations.get(server_name)

    def get_tool_registration(self, server_name: str, tool_name: str) -> Optional[ToolRegistration]:
        """Get registration for a tool."""
        return self._tool_registrations.get(self._tool_key(server_name, tool_name))

    def list_servers(self) -> List[Dict[str, Any]]:
        """List all registered servers (tool lists are sorted for stable output)."""
        return [
            {
                "server_name": name,
                "access_level": reg.access_level.value,
                "description": reg.description,
                "allowed_tools": sorted(reg.allowed_tools) if reg.allowed_tools is not None else None,
                "blocked_tools": sorted(reg.blocked_tools),
                "approval_required_tools": sorted(reg.approval_required_tools),
                "registered_at": reg.registered_at.isoformat(),
                "last_used_at": reg.last_used_at.isoformat() if reg.last_used_at else None,
            }
            for name, reg in self._server_registrations.items()
        ]

    def list_tools(self, server_name: Optional[str] = None) -> List[Dict[str, Any]]:
        """List registered tools, optionally filtered by server."""
        return [
            {
                "server_name": reg.server_name,
                "tool_name": reg.tool_name,
                "access_level": reg.access_level.value,
                "requires_approval": reg.requires_approval,
                "risk_level": reg.risk_level,
                "description": reg.description,
            }
            for reg in self._tool_registrations.values()
            if not server_name or reg.server_name == server_name
        ]

    def export_config(self) -> Dict[str, Any]:
        """Export registry configuration (lists sorted for deterministic output)."""
        return {
            "allowed_servers": sorted(self._allowed_servers),
            "blocked_servers": sorted(self._blocked_servers),
            "approval_required_tools": {
                k: sorted(v) for k, v in sorted(self._approval_required_tools.items())
            },
            "allow_unknown_servers": self._allow_unknown,
            "servers": self.list_servers(),
            "tools": self.list_tools(),
        }

    def _reconcile_server_registrations(self) -> None:
        """Align server registrations with the allow/block sets.

        - Blocklisted servers become BLOCKED.
        - Allowlisted servers that were BLOCKED/UNKNOWN become ALLOWED;
          APPROVAL_REQUIRED is kept.
        - ALLOWED/BLOCKED registrations for servers on neither list become
          UNKNOWN. That makes ``allow_unknown_servers`` decide for them while
          their per-tool restrictions are kept.
        """
        for name, registration in self._server_registrations.items():
            if name in self._blocked_servers:
                registration.access_level = AccessLevel.BLOCKED
            elif name in self._allowed_servers:
                if registration.access_level in (AccessLevel.BLOCKED, AccessLevel.UNKNOWN):
                    registration.access_level = AccessLevel.ALLOWED
            elif registration.access_level in (AccessLevel.ALLOWED, AccessLevel.BLOCKED):
                registration.access_level = AccessLevel.UNKNOWN

    def import_config(self, config: Dict[str, Any]) -> None:
        """
        Import registry configuration.

        Accepts the keys produced by :meth:`export_config`. Keys present in
        *config* replace the corresponding state; absent keys leave it as is.
        When either server list is imported, existing registrations are
        reconciled with the new lists, so stale registrations cannot keep a
        server allowed or blocked contrary to the import.

        NOTE: the detailed ``"servers"`` / ``"tools"`` entries written by
        export_config are not read here. Per-tool registrations and per-server
        tool lists are not restored from a config.
        """
        lists_changed = False
        if "allowed_servers" in config:
            self._allowed_servers = set(config["allowed_servers"])
            lists_changed = True

        if "blocked_servers" in config:
            self._blocked_servers = set(config["blocked_servers"])
            lists_changed = True

        if "approval_required_tools" in config:
            self._approval_required_tools = {
                k: set(v) for k, v in config["approval_required_tools"].items()
            }
            # Keep each registration's listed approval tools in step with the
            # map that check_access() actually enforces.
            for name, registration in self._server_registrations.items():
                registration.approval_required_tools = self._approval_required_tools.get(name, set())

        if "allow_unknown_servers" in config:
            self._allow_unknown = config["allow_unknown_servers"]

        if lists_changed:
            self._reconcile_server_registrations()

        # Register newly listed servers that have no registration yet
        self._init_defaults()
        logger.info("Imported MCP registry configuration")


# Process-wide registry slot; stays None until set_registry() installs one.
_registry: Optional[MCPRegistry] = None


def get_registry() -> Optional[MCPRegistry]:
    """Return the process-wide MCP registry, or None if none has been installed."""
    return _registry


def set_registry(registry: MCPRegistry) -> None:
    """Install *registry* as the process-wide MCP registry, replacing any previous one."""
    global _registry
    _registry = registry


def is_server_allowed(server_name: str) -> bool:
    """Return whether the process-wide registry allows *server_name*.

    With no registry installed this deliberately fails open and returns True.
    """
    active = get_registry()
    # No registry configured -> every server is permitted.
    return True if active is None else active.is_allowed(server_name)


def is_tool_allowed(server_name: str, tool_name: str) -> bool:
    """Return whether the process-wide registry allows *tool_name* on *server_name*.

    With no registry installed this deliberately fails open and returns True.
    """
    active = get_registry()
    if active is not None:
        return active.is_allowed(server_name, tool_name)
    return True  # No registry configured, allow all
