"""
Knowledge Memory - RAG-based retrieval over operational knowledge.

This module provides:
- Indexing of /ops/ directory files
- Similarity-based retrieval
- Context-aware knowledge lookup

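Retrieval backends:
1. Keyword overlap scoring - default, no extra dependencies
2. Local embeddings (sentence-transformers) - opt-in via use_embeddings=True
3. pgvector (PostgreSQL) - intended for production scale; not implemented
   in this module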

Usage:
    from agent_orchestrator.memory.knowledge import KnowledgeMemory

    km = KnowledgeMemory(ops_path=Path("ops"))
    km.index_all()

    results = km.retrieve("How to recover a stuck agent?", top_k=5)
"""

import hashlib
import json
import logging
import os
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Optional

from .models import (
    MemoryItem,
    MemoryItemType,
    MemoryTier,
)


logger = logging.getLogger(__name__)


@dataclass
class SearchResult:
    """One hit produced by knowledge retrieval.

    Attributes:
        item: The matched knowledge item.
        score: Relevance score. Keyword scoring yields values in 0-1;
            cosine similarity from embedding scoring may fall outside that.
        snippet: Short excerpt of the item's content.
        source_file: Originating file (relative to the ops directory), if known.
    """

    item: MemoryItem
    score: float
    snippet: str
    source_file: Optional[str] = field(default=None)


@dataclass
class KnowledgeIndex:
    """Mutable in-memory store backing knowledge retrieval.

    Each container is created fresh per instance via ``default_factory``.
    """

    # item id -> indexed MemoryItem
    items: dict[str, MemoryItem] = field(default_factory=dict)
    # item id -> embedding vector (only filled when embeddings are enabled)
    embeddings: dict[str, list[float]] = field(default_factory=dict)
    # ops-relative file key -> content hash, used to skip unchanged files
    file_hashes: dict[str, str] = field(default_factory=dict)
    # when the last full indexing pass finished; None until the first one
    last_indexed: Optional[datetime] = field(default=None)


class KnowledgeMemory:
    """
    RAG-ready knowledge memory over the /ops/ directory.

    Indexes:
    - ADRs (Architecture Decision Records)  -> <ops>/decisions/**/*.md
    - Runbooks (Operational procedures)     -> <ops>/runbooks/**/*.md
    - Patterns (Known issues and fixes)     -> <ops>/patterns/**/*.md
    - Prompts (Agent templates)             -> <ops>/prompts/**/*.md
    - Project state summary                 -> <ops>/project_state.json

    Retrieval is keyword-overlap based by default; with ``use_embeddings=True``
    items are ranked by cosine similarity of sentence-transformers embeddings.

    Retrieval policy:
    - Maximum ``MAX_RETRIEVAL_ITEMS`` (5) items per query
    - ADRs are available via ``get_relevant_adrs`` for task-type lookups
    - Runbooks are fetched via ``get_relevant_runbooks`` only when needed
    """

    MAX_RETRIEVAL_ITEMS = 5

    # File key (relative to ops_path) of the project state document.
    _PROJECT_STATE_FILE = "project_state.json"

    # Word tokenizer used for keyword scoring and snippet selection; drops
    # punctuation so "agent?" in a query still matches "agent" in content.
    _TOKEN_RE = re.compile(r"\w+")

    # Keywords promoted to tags when they start a word in the content.
    _TAG_KEYWORDS = (
        "agent", "recovery", "workspace", "git", "tmux", "error",
        "api", "budget", "approval", "risk", "memory",
    )

    def __init__(
        self,
        ops_path: Path,
        use_embeddings: bool = False,
        embedding_model: str = "all-MiniLM-L6-v2",
    ):
        """
        Initialize Knowledge Memory.

        Args:
            ops_path: Path to the /ops/ directory
            use_embeddings: Whether to use semantic embeddings (requires sentence-transformers)
            embedding_model: Name of the sentence-transformers model
        """
        self.ops_path = Path(ops_path)
        self.use_embeddings = use_embeddings
        self.embedding_model_name = embedding_model
        self._index = KnowledgeIndex()

        # Lazy load embedding model
        self._embedding_model = None

        # Subdirectory name -> item type assigned to its markdown files
        self._file_patterns = {
            "decisions": MemoryItemType.ADR,
            "runbooks": MemoryItemType.RUNBOOK,
            "patterns": MemoryItemType.PATTERN,
            "prompts": MemoryItemType.PROMPT,
        }

    def _get_embedding_model(self):
        """
        Lazily load the sentence-transformers model.

        Returns:
            The model, or None when embeddings are disabled or loading failed.
            Any load failure switches ``use_embeddings`` off so retrieval
            falls back to keyword search instead of crashing.
        """
        if self._embedding_model is None and self.use_embeddings:
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError:
                logger.warning("sentence-transformers not installed, falling back to keyword search")
                self.use_embeddings = False
                return None
            try:
                self._embedding_model = SentenceTransformer(self.embedding_model_name)
            except Exception:  # e.g. model download or deserialization failure
                logger.exception(
                    "Failed to load embedding model %s, falling back to keyword search",
                    self.embedding_model_name,
                )
                self.use_embeddings = False
                return None
            logger.info("Loaded embedding model: %s", self.embedding_model_name)
        return self._embedding_model

    def _compute_embedding(self, text: str) -> Optional[list[float]]:
        """Return the embedding of *text* as a list, or None if no model is available."""
        model = self._get_embedding_model()
        if model is None:
            return None
        return model.encode(text).tolist()

    def _compute_file_hash(self, content: str) -> str:
        """Return a hex MD5 digest of *content* (change detection, not security)."""
        # usedforsecurity=False keeps this working on FIPS-restricted builds.
        return hashlib.md5(content.encode("utf-8"), usedforsecurity=False).hexdigest()

    @classmethod
    def _tokenize(cls, text: str) -> set[str]:
        """Lower-case *text* and return its set of word tokens (punctuation dropped)."""
        return set(cls._TOKEN_RE.findall(text.lower()))

    @staticmethod
    def _as_list(value: Any) -> list[Any]:
        """Normalize an optional JSON value to a list (falsy -> [], scalar -> [scalar])."""
        if not value:
            return []
        return value if isinstance(value, list) else [value]

    def _forget_source(self, file_key: str) -> None:
        """Remove every item (and embedding) built from *file_key*, plus its hash."""
        stale_ids = [
            item_id
            for item_id, item in self._index.items.items()
            if item.metadata.get("source_file") == file_key
        ]
        for item_id in stale_ids:
            del self._index.items[item_id]
            self._index.embeddings.pop(item_id, None)
        self._index.file_hashes.pop(file_key, None)

    def _store_item(self, item: MemoryItem, file_key: str, content_hash: str, embed_text: str) -> None:
        """Add *item* to the index, record its source hash, and embed it if enabled."""
        self._index.items[item.id] = item
        self._index.file_hashes[file_key] = content_hash
        if self.use_embeddings:
            embedding = self._compute_embedding(embed_text)
            if embedding:
                self._index.embeddings[item.id] = embedding

    def index_all(self, force: bool = False) -> int:
        """
        Index all files in the /ops/ directory.

        Files that were indexed previously but no longer exist on disk are
        removed from the index.

        Args:
            force: Force re-indexing even if content unchanged

        Returns:
            Number of items indexed (new or updated) in this pass
        """
        if not self.ops_path.exists():
            logger.warning("Ops path does not exist: %s", self.ops_path)
            return 0

        indexed_count = 0
        seen: set[str] = set()  # file keys present on disk during this pass

        for subdir, item_type in self._file_patterns.items():
            subdir_path = self.ops_path / subdir
            if subdir_path.is_dir():
                indexed_count += self._index_directory(subdir_path, item_type, force, seen)

        state_file = self.ops_path / self._PROJECT_STATE_FILE
        if state_file.is_file():
            seen.add(self._PROJECT_STATE_FILE)
            indexed_count += self._index_project_state(state_file, force)

        # Drop knowledge whose source file has been deleted since the last pass.
        removed = [key for key in self._index.file_hashes if key not in seen]
        for file_key in removed:
            self._forget_source(file_key)
        if removed:
            logger.info("Removed %d deleted knowledge sources", len(removed))

        self._index.last_indexed = datetime.now()
        logger.info("Indexed %d knowledge items", indexed_count)

        return indexed_count

    def _index_directory(
        self,
        directory: Path,
        item_type: MemoryItemType,
        force: bool,
        seen: Optional[set[str]] = None,
    ) -> int:
        """
        Index every markdown file under *directory* (recursively).

        Args:
            directory: Directory to scan for ``*.md`` files
            item_type: Item type assigned to every file found
            force: Re-index even unchanged files
            seen: If given, receives the ops-relative key of every file found

        Returns:
            Number of files indexed or updated
        """
        indexed = 0
        # Sorted for a deterministic indexing order across platforms.
        for file_path in sorted(directory.glob("**/*.md")):
            if not file_path.is_file():
                continue
            if seen is not None:
                seen.add(str(file_path.relative_to(self.ops_path)))
            if self._index_file(file_path, item_type, force):
                indexed += 1
        return indexed

    def _index_file(self, file_path: Path, item_type: MemoryItemType, force: bool) -> bool:
        """
        Index a single markdown file, replacing any previous version of it.

        Returns:
            True if file was indexed/updated
        """
        try:
            content = file_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            logger.error("Failed to read %s: %s", file_path, e)
            return False

        content_hash = self._compute_file_hash(content)
        file_key = str(file_path.relative_to(self.ops_path))

        if not force and self._index.file_hashes.get(file_key) == content_hash:
            return False  # No change

        # Remove the item built from the previous version of this file so an
        # edit replaces the old entry rather than leaving a stale duplicate.
        self._forget_source(file_key)

        title = self._extract_title(content, file_path.name)
        tags = self._extract_tags(content, item_type)

        # Key the ID on path + content so files with identical content do
        # not overwrite each other's entry.
        id_hash = self._compute_file_hash(f"{file_key}:{content_hash}")
        item_id = f"km-{id_hash[:12]}"
        item = MemoryItem(
            id=item_id,
            tier=MemoryTier.WORKING_KNOWLEDGE,
            item_type=item_type,
            title=title,
            content=content,
            tags=tags,
            source_links=[str(file_path)],
            created_by="indexer",
            metadata={"source_file": file_key},
        )

        self._store_item(item, file_key, content_hash, f"{title}\n{content}")
        return True

    def _index_project_state(self, state_file: Path, force: bool) -> int:
        """
        Index a condensed summary of ``project_state.json``.

        Returns:
            1 if the state item was (re)indexed, otherwise 0
        """
        try:
            content = state_file.read_text(encoding="utf-8")
            state_data = json.loads(content)
        except (OSError, UnicodeDecodeError, ValueError) as e:
            logger.error("Failed to read project state: %s", e)
            return 0

        if not isinstance(state_data, dict):
            logger.error("Project state in %s is not a JSON object", state_file)
            return 0

        content_hash = self._compute_file_hash(content)
        file_key = self._PROJECT_STATE_FILE

        if not force and self._index.file_hashes.get(file_key) == content_hash:
            return 0

        # Replace (not accumulate) the previous state item.
        self._forget_source(file_key)

        project = state_data.get("project")
        project_name = project.get("name", "Unknown") if isinstance(project, dict) else "Unknown"
        summary_parts = [
            f"Project: {project_name}",
            f"Phase: {state_data.get('current_phase', 'Unknown')}",
        ]

        constraints = self._as_list(state_data.get("constraints"))
        if constraints:
            summary_parts.append("\nConstraints:")
            summary_parts.extend(f"- {c}" for c in constraints)

        decisions = self._as_list(state_data.get("decisions"))
        if decisions:
            summary_parts.append("\nRecent Decisions:")
            for d in decisions[-3:]:  # only the three most recent
                text = d.get("decision", "") if isinstance(d, dict) else d
                summary_parts.append(f"- {text}")

        summary = "\n".join(summary_parts)
        title = "Current Project State"

        item = MemoryItem(
            id=f"km-state-{content_hash[:8]}",
            tier=MemoryTier.AUTHORITATIVE,
            item_type=MemoryItemType.PROJECT_STATE,
            title=title,
            content=summary,
            tags=["project-state", "authoritative"],
            source_links=[str(state_file)],
            created_by="indexer",
            metadata={"raw_state": state_data, "source_file": file_key},
        )

        self._store_item(item, file_key, content_hash, f"{title}\n{summary}")
        return 1

    def _extract_title(self, content: str, filename: str) -> str:
        """Return the first H1 heading, or a title derived from *filename*."""
        # [ \t]+ (not \s+) so a bare "#" line can't swallow the next line.
        match = re.search(r"^#[ \t]+(.+)$", content, re.MULTILINE)
        if match:
            return match.group(1).strip()
        return filename.removesuffix(".md").replace("-", " ").title()

    def _extract_tags(self, content: str, item_type: MemoryItemType) -> list[str]:
        """
        Build the tag list for an item.

        Combines the item type, any ``**Tags**: a, b`` line, and known
        keywords that begin a word in the content. Order is deterministic.
        """
        tags = [item_type.value]

        match = re.search(r"\*\*Tags\*\*:[ \t]*(.+)$", content, re.MULTILINE)
        if match:
            tags.extend(t.strip() for t in match.group(1).split(",") if t.strip())

        content_lower = content.lower()
        for keyword in self._TAG_KEYWORDS:
            # Word-start match: "api" must not fire on "rapid" nor "git" on
            # "digit", while plurals such as "agents" still count.
            if re.search(rf"\b{re.escape(keyword)}", content_lower):
                tags.append(keyword)

        # Deduplicate while keeping first-seen order.
        return list(dict.fromkeys(tags))

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        item_types: Optional[list[MemoryItemType]] = None,
        tags: Optional[list[str]] = None,
    ) -> list[SearchResult]:
        """
        Retrieve relevant knowledge items.

        Args:
            query: Search query
            top_k: Maximum number of results (capped at MAX_RETRIEVAL_ITEMS;
                values <= 0 return no results)
            item_types: Filter by item types
            tags: Filter by tags

        Returns:
            List of SearchResult ordered by relevance
        """
        top_k = max(0, min(top_k, self.MAX_RETRIEVAL_ITEMS))
        if top_k == 0:
            return []

        candidates = list(self._index.items.values())

        if item_types:
            candidates = [c for c in candidates if c.item_type in item_types]

        if tags:
            candidates = [c for c in candidates if c.matches_tags(tags)]

        if self.use_embeddings and self._index.embeddings:
            return self._score_by_embedding(query, candidates, top_k)
        return self._score_by_keywords(query, candidates, top_k)

    def _score_by_keywords(
        self,
        query: str,
        candidates: list[MemoryItem],
        top_k: int,
    ) -> list[SearchResult]:
        """Rank candidates by weighted word overlap (title x3, tags x2, content x1)."""
        query_words = self._tokenize(query)
        if not query_words:
            return []

        scored = []
        for item in candidates:
            title_overlap = len(query_words & self._tokenize(item.title))
            tag_words = {tok for tag in item.tags for tok in self._tokenize(tag)}
            tag_overlap = len(query_words & tag_words)
            content_overlap = len(query_words & self._tokenize(item.content))

            score = (title_overlap * 3 + tag_overlap * 2 + content_overlap) / len(query_words)
            if score <= 0:
                continue

            scored.append(SearchResult(
                item=item,
                score=min(score / 5, 1.0),  # Normalize to 0-1
                snippet=self._extract_snippet(item.content, query),
                source_file=item.metadata.get("source_file"),
            ))

        scored.sort(key=lambda x: x.score, reverse=True)
        return scored[:top_k]

    def _score_by_embedding(
        self,
        query: str,
        candidates: list[MemoryItem],
        top_k: int,
    ) -> list[SearchResult]:
        """Rank candidates by cosine similarity; falls back to keywords without a query embedding."""
        query_embedding = self._compute_embedding(query)
        if not query_embedding:
            return self._score_by_keywords(query, candidates, top_k)

        scored = []
        for item in candidates:
            item_embedding = self._index.embeddings.get(item.id)
            if item_embedding is None:
                continue  # item was indexed without an embedding

            scored.append(SearchResult(
                item=item,
                score=self._cosine_similarity(query_embedding, item_embedding),
                snippet=self._extract_snippet(item.content, query),
                source_file=item.metadata.get("source_file"),
            ))

        scored.sort(key=lambda x: x.score, reverse=True)
        return scored[:top_k]

    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """Return the cosine similarity of *a* and *b* (0.0 if either is a zero vector)."""
        dot_product = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5

        if norm_a == 0 or norm_b == 0:
            return 0.0

        return dot_product / (norm_a * norm_b)

    def _extract_snippet(self, content: str, query: str, max_length: int = 200) -> str:
        """
        Return the paragraph sharing the most words with *query*.

        Falls back to the first non-empty paragraph when nothing matches.
        The result is truncated to *max_length* characters plus "...".
        """
        query_words = self._tokenize(query)
        paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
        if not paragraphs:
            return ""

        best_paragraph = paragraphs[0]
        best_hits = 0
        for para in paragraphs:
            hits = len(query_words & self._tokenize(para))
            if hits > best_hits:  # strict ">" keeps the earliest paragraph on ties
                best_paragraph, best_hits = para, hits

        if len(best_paragraph) > max_length:
            best_paragraph = best_paragraph[:max_length].rstrip() + "..."

        return best_paragraph

    def get_relevant_runbooks(self, task_description: str, top_k: int = 3) -> list[SearchResult]:
        """Get runbooks relevant to a task."""
        return self.retrieve(
            task_description,
            top_k=top_k,
            item_types=[MemoryItemType.RUNBOOK],
        )

    def get_relevant_adrs(self, topic: str, top_k: int = 3) -> list[SearchResult]:
        """Get ADRs relevant to a topic."""
        return self.retrieve(
            topic,
            top_k=top_k,
            item_types=[MemoryItemType.ADR],
        )

    def get_known_issues(self, error_pattern: str, top_k: int = 3) -> list[SearchResult]:
        """Get known issues (patterns / issue fixes) matching an error pattern."""
        return self.retrieve(
            error_pattern,
            top_k=top_k,
            item_types=[MemoryItemType.PATTERN, MemoryItemType.ISSUE_FIX],
        )

    def get_project_state_summary(self) -> Optional[str]:
        """Return the indexed project state summary, or None if not indexed."""
        for item in self._index.items.values():
            if item.item_type == MemoryItemType.PROJECT_STATE:
                return item.content
        return None

    def get_agent_prompt(self, agent_type: str) -> Optional[str]:
        """Return the first prompt whose title contains *agent_type* (case-insensitive)."""
        needle = agent_type.lower()
        for item in self._index.items.values():
            if item.item_type == MemoryItemType.PROMPT and needle in item.title.lower():
                return item.content
        return None

    def get_item_count(self) -> dict[str, int]:
        """Get count of indexed items keyed by item type value."""
        counts: dict[str, int] = {}
        for item in self._index.items.values():
            key = item.item_type.value
            counts[key] = counts.get(key, 0) + 1
        return counts

    def rebuild_index(self) -> int:
        """Discard the whole index and re-index everything from disk."""
        self._index = KnowledgeIndex()
        return self.index_all(force=True)
