"""
Memory Write Gate - Validates and routes memory patches.

The Write Gate ensures that:
1. No secrets are stored in memory
2. Patches fit the expected schema
3. Evidence is provided for changes
4. Risk level is correctly classified
5. Appropriate approval is obtained

Usage:
    from agent_orchestrator.memory.write_gate import MemoryWriteGate

    gate = MemoryWriteGate(db, secret_redactor)
    result = gate.validate_patch(patch)

    if result.valid:
        if gate.requires_approval(patch):
            gate.route_to_approval(patch)
        else:
            gate.apply_patch(patch)
"""

import logging
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Optional

from .models import (
    MemoryItem,
    MemoryPatch,
    MemoryTier,
    MemoryItemType,
    PatchOperation,
    PatchRiskLevel,
    PatchValidationResult,
)
from ..secrets.redactor import SecretRedactor
from ..persistence.database import OrchestratorDB


# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


# Regular expressions whose match marks patch content as containing a secret.
# They are raw pattern strings; the write gate compiles them with
# re.IGNORECASE, so the character classes below also match the other case.
SECRET_PATTERNS = [
    r"sk-[a-zA-Z0-9]{20,}",  # OpenAI-style API keys (sk-<alphanumerics>)
    r"sk-proj-[a-zA-Z0-9_-]{20,}",  # OpenAI project keys (may contain _ and -)
    r"sk-ant-[a-zA-Z0-9_-]{20,}",  # Anthropic API keys (hyphenated, e.g. sk-ant-api03-...)
    r"xox[baprs]-[a-zA-Z0-9-]+",  # Slack tokens
    r"ghp_[a-zA-Z0-9]{36}",  # GitHub PAT
    r"gho_[a-zA-Z0-9]{36}",  # GitHub OAuth
    r"gh[usr]_[a-zA-Z0-9]{36}",  # GitHub user-to-server / server-to-server / refresh tokens
    r"github_pat_[a-zA-Z0-9_]{22,}",  # GitHub fine-grained PAT
    r"AIza[a-zA-Z0-9_-]{35}",  # Google API key
    r"ya29\.[a-zA-Z0-9_-]+",  # Google OAuth token
    r"AKIA[A-Z0-9]{16}",  # AWS Access Key
    r"password\s*[=:]\s*['\"][^'\"]+['\"]",  # Password assignments
    r"secret\s*[=:]\s*['\"][^'\"]+['\"]",  # Secret assignments
    # PEM private-key headers with any number of leading qualifier words:
    # plain PKCS#8, RSA, EC, DSA, OPENSSH, ENCRYPTED, ...
    r"-----BEGIN (?:[A-Z0-9]+ )*PRIVATE KEY-----",
    r"-----BEGIN CERTIFICATE-----",  # Certificates
]

# Operations that always require approval.
# NOTE(review): no visible code in this module reads this set;
# MemoryWriteGate._compute_risk_level compares against PatchOperation.DELETE
# directly. Confirm whether other modules import it before changing it.
HIGH_RISK_OPERATIONS = {
    PatchOperation.DELETE,
}

# Item types that require approval for modification.
# MemoryWriteGate._compute_risk_level returns HIGH when either the patch's
# new item or the existing target item has one of these types.
PROTECTED_ITEM_TYPES = {
    MemoryItemType.ADR,
    MemoryItemType.POLICY,
    MemoryItemType.PROJECT_STATE,
    MemoryItemType.CONSTRAINT,
}


@dataclass
class ApprovalRequest:
    """Request for human approval of a memory patch.

    Instances are built by MemoryWriteGate.route_to_approval and passed to the
    gate's optional approval callback.
    """

    # The patch awaiting a decision.
    patch: MemoryPatch
    # Why the change was requested; route_to_approval substitutes a generic
    # message when the patch carries no reason.
    reason: str
    # Human-readable risk summary, e.g. "Risk level: <value>".
    risk_assessment: str
    # Comma-separated evidence links, or the string "None" when there are none.
    evidence_summary: str
    recommended_action: str  # "approve", "reject", "review"
    # Creation time; route_to_approval fills it with datetime.now() (naive).
    requested_at: datetime
    # Defaults to 60. Nothing in this module's visible code reads this value.
    timeout_minutes: int = 60


class MemoryWriteGate:
    """
    Validates and routes memory patches.

    The Write Gate is the single point of entry for all memory modifications.
    It ensures safety, consistency, and auditability.
    """

    def __init__(
        self,
        db: OrchestratorDB,
        secret_redactor: Optional[SecretRedactor] = None,
        auto_apply_low_risk: bool = True,
        approval_callback: Optional[Callable[[ApprovalRequest], bool]] = None,
    ):
        """
        Initialize the Memory Write Gate.

        Args:
            db: Database for persistence and approval queue
            secret_redactor: Optional custom secret redactor
            auto_apply_low_risk: Whether to auto-apply LOW risk patches
            approval_callback: Callback for requesting approval
        """
        self.db = db
        self.secret_redactor = secret_redactor or SecretRedactor()
        self.auto_apply_low_risk = auto_apply_low_risk
        self.approval_callback = approval_callback

        # Compile secret patterns
        self._secret_patterns = [re.compile(p, re.IGNORECASE) for p in SECRET_PATTERNS]

        # In-memory storage for memory items (could be backed by DB)
        self._memory_items: dict[str, MemoryItem] = {}
        self._pending_patches: dict[str, MemoryPatch] = {}

    def validate_patch(self, patch: MemoryPatch) -> PatchValidationResult:
        """
        Validate a memory patch.

        Checks:
        1. No secrets in content
        2. Schema validation
        3. Evidence provided
        4. Risk level appropriate
        5. Target exists (for updates)

        Args:
            patch: The patch to validate

        Returns:
            PatchValidationResult with validation status
        """
        result = PatchValidationResult(valid=True)

        # Check for secrets
        self._check_no_secrets(patch, result)

        # Check schema
        self._check_schema(patch, result)

        # Check evidence
        self._check_evidence(patch, result)

        # Check risk level
        computed_risk = self._compute_risk_level(patch)
        result.computed_risk_level = computed_risk

        if computed_risk != patch.risk_level:
            result.add_warning(
                f"Risk level adjusted from {patch.risk_level.value} to {computed_risk.value}"
            )

        # Check target exists for updates
        if patch.operation in {PatchOperation.UPDATE, PatchOperation.TAG, PatchOperation.SUPERSEDE}:
            if patch.target_id and patch.target_id not in self._memory_items:
                result.add_error(f"Target item not found: {patch.target_id}")

        return result

    def _check_no_secrets(self, patch: MemoryPatch, result: PatchValidationResult) -> None:
        """Check that patch content contains no secrets."""
        content_to_check = []

        if patch.new_item:
            content_to_check.append(patch.new_item.title)
            content_to_check.append(patch.new_item.content)
            content_to_check.extend(patch.new_item.tags)

        if patch.changes:
            for key, value in patch.changes.items():
                if isinstance(value, str):
                    content_to_check.append(value)
                elif isinstance(value, list):
                    content_to_check.extend(str(v) for v in value)

        combined_content = " ".join(content_to_check)

        for pattern in self._secret_patterns:
            if pattern.search(combined_content):
                result.add_error(
                    f"Secret detected in patch content (pattern: {pattern.pattern[:30]}...)"
                )
                logger.warning(f"Patch {patch.id} rejected: secret detected")
                break

    def _check_schema(self, patch: MemoryPatch, result: PatchValidationResult) -> None:
        """Check that patch follows expected schema."""
        # ADD operations must have new_item
        if patch.operation == PatchOperation.ADD:
            if not patch.new_item:
                result.add_error("ADD operation requires new_item")
            else:
                # Validate item fields
                if not patch.new_item.title:
                    result.add_error("Memory item must have a title")
                if not patch.new_item.content:
                    result.add_error("Memory item must have content")

        # UPDATE operations must have target_id and changes or new_item
        if patch.operation == PatchOperation.UPDATE:
            if not patch.target_id:
                result.add_error("UPDATE operation requires target_id")
            if not patch.changes and not patch.new_item:
                result.add_error("UPDATE operation requires changes or new_item")

        # TAG operations must have target_id and tag changes
        if patch.operation == PatchOperation.TAG:
            if not patch.target_id:
                result.add_error("TAG operation requires target_id")
            if "add_tags" not in patch.changes and "remove_tags" not in patch.changes:
                result.add_error("TAG operation requires add_tags or remove_tags in changes")

        # DELETE operations must have target_id
        if patch.operation == PatchOperation.DELETE:
            if not patch.target_id:
                result.add_error("DELETE operation requires target_id")

    def _check_evidence(self, patch: MemoryPatch, result: PatchValidationResult) -> None:
        """Check that patch has supporting evidence."""
        # Evidence required for content changes
        needs_evidence = patch.operation in {
            PatchOperation.ADD,
            PatchOperation.UPDATE,
            PatchOperation.SUPERSEDE,
        }

        if needs_evidence:
            if not patch.evidence_links:
                result.add_warning("No evidence links provided for content change")

            if not patch.reason:
                result.add_warning("No reason provided for change")

    def _compute_risk_level(self, patch: MemoryPatch) -> PatchRiskLevel:
        """Compute the appropriate risk level for a patch."""
        # DELETE is always HIGH
        if patch.operation == PatchOperation.DELETE:
            return PatchRiskLevel.HIGH

        # Changes to protected item types are HIGH
        if patch.new_item and patch.new_item.item_type in PROTECTED_ITEM_TYPES:
            return PatchRiskLevel.HIGH

        # Check if modifying a protected item
        if patch.target_id and patch.target_id in self._memory_items:
            target = self._memory_items[patch.target_id]
            if target.item_type in PROTECTED_ITEM_TYPES:
                return PatchRiskLevel.HIGH
            if target.tier in {MemoryTier.IMMUTABLE_AUDIT, MemoryTier.AUTHORITATIVE}:
                return PatchRiskLevel.HIGH

        # TAG operations are LOW
        if patch.operation == PatchOperation.TAG:
            return PatchRiskLevel.LOW

        # ADD operations for working knowledge are MEDIUM
        if patch.operation == PatchOperation.ADD:
            if patch.new_item and patch.new_item.tier == MemoryTier.WORKING_KNOWLEDGE:
                if patch.new_item.item_type == MemoryItemType.ISSUE_FIX:
                    return PatchRiskLevel.LOW  # Issue fixes are low risk
                return PatchRiskLevel.MEDIUM

        # Default to MEDIUM
        return PatchRiskLevel.MEDIUM

    def requires_approval(self, patch: MemoryPatch) -> bool:
        """Check if a patch requires human approval."""
        risk_level = self._compute_risk_level(patch)

        # HIGH always requires approval
        if risk_level == PatchRiskLevel.HIGH:
            return True

        # MEDIUM may require approval based on config
        if risk_level == PatchRiskLevel.MEDIUM:
            return True  # Conservative default

        # LOW auto-applies if configured
        return not self.auto_apply_low_risk

    def route_to_approval(self, patch: MemoryPatch) -> str:
        """
        Route a patch to the approval queue.

        Returns:
            Approval request ID
        """
        risk_level = self._compute_risk_level(patch)

        # Store pending patch
        self._pending_patches[patch.id] = patch
        patch.status = "pending_approval"

        # Create approval request
        request = ApprovalRequest(
            patch=patch,
            reason=patch.reason or "Memory modification requested",
            risk_assessment=f"Risk level: {risk_level.value}",
            evidence_summary=", ".join(patch.evidence_links) if patch.evidence_links else "None",
            recommended_action="review",
            requested_at=datetime.now(),
        )

        # If we have an approval callback, use it
        if self.approval_callback:
            try:
                approved = self.approval_callback(request)
                if approved:
                    return self.apply_patch(patch)
                else:
                    patch.status = "rejected"
                    patch.rejection_reason = "Rejected via approval callback"
                    return patch.id
            except Exception as e:
                logger.error(f"Approval callback failed: {e}")
                patch.status = "pending"

        logger.info(f"Patch {patch.id} routed to approval queue (risk: {risk_level.value})")
        return patch.id

    def apply_patch(self, patch: MemoryPatch) -> str:
        """
        Apply a validated and approved patch.

        Args:
            patch: The patch to apply

        Returns:
            ID of the affected memory item
        """
        if patch.operation == PatchOperation.ADD:
            return self._apply_add(patch)
        elif patch.operation == PatchOperation.UPDATE:
            return self._apply_update(patch)
        elif patch.operation == PatchOperation.TAG:
            return self._apply_tag(patch)
        elif patch.operation == PatchOperation.SUPERSEDE:
            return self._apply_supersede(patch)
        elif patch.operation == PatchOperation.ARCHIVE:
            return self._apply_archive(patch)
        elif patch.operation == PatchOperation.DELETE:
            return self._apply_delete(patch)
        else:
            raise ValueError(f"Unknown operation: {patch.operation}")

    def _apply_add(self, patch: MemoryPatch) -> str:
        """Apply an ADD operation."""
        if not patch.new_item:
            raise ValueError("ADD operation requires new_item")

        item = patch.new_item
        self._memory_items[item.id] = item

        patch.status = "applied"
        logger.info(f"Added memory item: {item.id} ({item.title})")

        return item.id

    def _apply_update(self, patch: MemoryPatch) -> str:
        """Apply an UPDATE operation."""
        if not patch.target_id:
            raise ValueError("UPDATE operation requires target_id")

        if patch.new_item:
            # Full replacement
            item = patch.new_item
            item.id = patch.target_id
            item.updated_at = datetime.now()
            self._memory_items[item.id] = item
        else:
            # Partial update
            item = self._memory_items[patch.target_id]
            for key, value in patch.changes.items():
                if hasattr(item, key):
                    setattr(item, key, value)
            item.updated_at = datetime.now()

        patch.status = "applied"
        logger.info(f"Updated memory item: {patch.target_id}")

        return patch.target_id

    def _apply_tag(self, patch: MemoryPatch) -> str:
        """Apply a TAG operation."""
        if not patch.target_id:
            raise ValueError("TAG operation requires target_id")

        item = self._memory_items[patch.target_id]

        if "add_tags" in patch.changes:
            for tag in patch.changes["add_tags"]:
                if tag not in item.tags:
                    item.tags.append(tag)

        if "remove_tags" in patch.changes:
            for tag in patch.changes["remove_tags"]:
                if tag in item.tags:
                    item.tags.remove(tag)

        item.updated_at = datetime.now()
        patch.status = "applied"
        logger.info(f"Updated tags for: {patch.target_id}")

        return patch.target_id

    def _apply_supersede(self, patch: MemoryPatch) -> str:
        """Apply a SUPERSEDE operation."""
        if not patch.target_id or not patch.new_item:
            raise ValueError("SUPERSEDE requires target_id and new_item")

        # Mark old item as superseded
        old_item = self._memory_items[patch.target_id]
        old_item.superseded_by = patch.new_item.id
        old_item.updated_at = datetime.now()

        # Add new item
        self._memory_items[patch.new_item.id] = patch.new_item

        patch.status = "applied"
        logger.info(f"Superseded {patch.target_id} with {patch.new_item.id}")

        return patch.new_item.id

    def _apply_archive(self, patch: MemoryPatch) -> str:
        """Apply an ARCHIVE operation."""
        if not patch.target_id:
            raise ValueError("ARCHIVE operation requires target_id")

        item = self._memory_items[patch.target_id]
        item.archived = True
        item.updated_at = datetime.now()

        patch.status = "applied"
        logger.info(f"Archived memory item: {patch.target_id}")

        return patch.target_id

    def _apply_delete(self, patch: MemoryPatch) -> str:
        """Apply a DELETE operation."""
        if not patch.target_id:
            raise ValueError("DELETE operation requires target_id")

        item = self._memory_items.pop(patch.target_id, None)
        if item:
            logger.warning(f"Deleted memory item: {patch.target_id} ({item.title})")
        else:
            logger.warning(f"Attempted to delete non-existent item: {patch.target_id}")

        patch.status = "applied"

        return patch.target_id

    def process_patch(self, patch: MemoryPatch) -> tuple[bool, str]:
        """
        Process a patch through validation, approval routing, and application.

        Args:
            patch: The patch to process

        Returns:
            Tuple of (success, message)
        """
        # Validate
        result = self.validate_patch(patch)

        if not result.valid:
            patch.status = "rejected"
            patch.rejection_reason = "; ".join(result.errors)
            return False, f"Validation failed: {patch.rejection_reason}"

        # Update risk level if computed differently
        if result.computed_risk_level:
            patch.risk_level = result.computed_risk_level

        # Check approval
        if self.requires_approval(patch):
            self.route_to_approval(patch)
            return True, f"Patch {patch.id} routed to approval (risk: {patch.risk_level.value})"

        # Apply
        try:
            item_id = self.apply_patch(patch)
            return True, f"Patch applied successfully, item: {item_id}"
        except Exception as e:
            patch.status = "failed"
            return False, f"Failed to apply patch: {e}"

    def get_memory_item(self, item_id: str) -> Optional[MemoryItem]:
        """Get a memory item by ID."""
        return self._memory_items.get(item_id)

    def get_all_items(self, tier: Optional[MemoryTier] = None) -> list[MemoryItem]:
        """Get all memory items, optionally filtered by tier."""
        items = list(self._memory_items.values())

        if tier is not None:
            items = [i for i in items if i.tier == tier]

        return items

    def get_active_items(self) -> list[MemoryItem]:
        """Get all active (non-archived, non-superseded) items."""
        return [i for i in self._memory_items.values() if i.is_active()]

    def get_pending_patches(self) -> list[MemoryPatch]:
        """Get all pending patches awaiting approval."""
        return [
            p for p in self._pending_patches.values()
            if p.status in {"pending", "pending_approval"}
        ]

    def approve_patch(self, patch_id: str, reviewer: str) -> tuple[bool, str]:
        """
        Approve a pending patch.

        Args:
            patch_id: ID of the patch to approve
            reviewer: ID of the reviewer

        Returns:
            Tuple of (success, message)
        """
        patch = self._pending_patches.get(patch_id)
        if not patch:
            return False, f"Patch not found: {patch_id}"

        if patch.status != "pending_approval":
            return False, f"Patch not pending approval: {patch.status}"

        patch.reviewed_by = reviewer
        patch.reviewed_at = datetime.now()

        try:
            item_id = self.apply_patch(patch)
            return True, f"Patch approved and applied, item: {item_id}"
        except Exception as e:
            patch.status = "failed"
            return False, f"Failed to apply approved patch: {e}"

    def reject_patch(self, patch_id: str, reviewer: str, reason: str) -> tuple[bool, str]:
        """
        Reject a pending patch.

        Args:
            patch_id: ID of the patch to reject
            reviewer: ID of the reviewer
            reason: Reason for rejection

        Returns:
            Tuple of (success, message)
        """
        patch = self._pending_patches.get(patch_id)
        if not patch:
            return False, f"Patch not found: {patch_id}"

        patch.status = "rejected"
        patch.reviewed_by = reviewer
        patch.reviewed_at = datetime.now()
        patch.rejection_reason = reason

        logger.info(f"Patch {patch_id} rejected by {reviewer}: {reason}")

        return True, f"Patch rejected: {reason}"
