#!/usr/bin/env python3
"""
Smart Context Selector - Intelligent chunk selection for LLM synthesis.

Prevents context window overflow by intelligently selecting the most relevant
chunks while respecting token budgets and document priorities.

Features:
- Token-aware chunk selection
- Document pinning priority
- Rerank score integration
- Overflow summarization
- Diversity sampling (avoid too many chunks from same document)

Usage:
    from context_selector import SmartContextSelector

    selector = SmartContextSelector(max_tokens=8000)
    selected = selector.select_context(
        chunks=search_results,
        pinned_doc_ids=['DOC_001', 'DOC_002'],
        rerank_scores={'chunk_1': 0.95, 'chunk_2': 0.87}
    )
"""

import functools
import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Any, Optional, Set, Tuple

# Put this file's own directory at the front of sys.path so sibling modules
# (e.g. the `config` module imported lazily in llm_summarize_chunks) can be
# imported by bare name.
sys.path.insert(0, str(Path(__file__).parent))

# Configure logging
# NOTE(review): basicConfig runs at import time, so importing this module
# configures the root logger (INFO level) for the whole application unless
# the root logger already has handlers. Consider moving it into main().
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# =============================================================================
# CONFIGURATION
# =============================================================================

# Total token budget used when a caller does not supply one.
DEFAULT_MAX_TOKENS: int = 8_000
# Fraction of the budget that chunks from pinned documents may consume (60%).
PINNED_BUDGET_PERCENT: float = 0.60
# Diversity cap: never take more than this many chunks from one document.
DIVERSITY_MAX_CHUNKS_PER_DOC: int = 5
# Intended as "summarize overflow if >10% of the budget is left over".
# NOTE(review): not referenced anywhere in this module; the overflow-summary
# gate in SmartContextSelector.select_context compares against 100 remaining
# tokens rather than this percentage.
SUMMARY_TRIGGER_PERCENT: float = 0.10


# =============================================================================
# TOKEN ESTIMATION
# =============================================================================

def estimate_tokens(text: str) -> int:
    """
    Estimate token count for text.

    Uses word count * 1.3 as rough approximation.
    For more accurate counts, use tiktoken.
    """
    if not text:
        return 0

    # Try tiktoken if available
    try:
        import tiktoken
        enc = tiktoken.get_encoding("cl100k_base")
        return len(enc.encode(text))
    except ImportError:
        pass

    # Fallback: word count * 1.3
    word_count = len(text.split())
    return int(word_count * 1.3)


# =============================================================================
# CHUNK SUMMARIZATION
# =============================================================================

def summarize_chunks(
    chunks: List[Dict[str, Any]],
    max_length: int = 500,
    max_chunks: int = 10,
    snippet_chars: int = 200
) -> str:
    """
    Create a brief extractive summary of multiple chunks (no LLM involved).

    For use when chunks don't fit in context but shouldn't be ignored. Each of
    the first ``max_chunks`` chunks contributes one line of the form
    ``From '<title>': <first snippet_chars characters of chunk_text>``, with
    ``...`` appended only when the snippet was cut. Lines are newline-joined.

    Args:
        chunks: List of chunk dicts with 'chunk_text' (and optionally 'title')
        max_length: Maximum character length of the returned summary; a
            truncated summary ends with '...' and still fits in max_length.
        max_chunks: Number of leading chunks to include.
        snippet_chars: Characters of each chunk's text to include.

    Returns:
        Summary string ('' for no chunks).
    """
    if not chunks:
        return ""

    ellipsis = "..."
    lines = []
    for chunk in chunks[:max_chunks]:
        # `or` also covers keys that are present but None.
        text = chunk.get('chunk_text') or ''
        title = chunk.get('title') or 'Unknown'
        snippet = text[:snippet_chars]
        if len(text) > snippet_chars:
            snippet += ellipsis
        lines.append(f"From '{title}': {snippet}")

    combined = "\n".join(lines)
    if len(combined) <= max_length:
        return combined

    # Keep the result within max_length, including the trailing ellipsis.
    if max_length <= len(ellipsis):
        return combined[:max(max_length, 0)]
    return combined[:max_length - len(ellipsis)] + ellipsis


def llm_summarize_chunks(
    chunks: List[Dict[str, Any]],
    query: str = None,
    max_tokens: int = 300,
    model: str = 'gpt-4o-mini'
) -> str:
    """
    Use LLM to create intelligent summary of overflow chunks.

    Falls back to the extractive ``summarize_chunks`` when the local
    ``config`` module says OpenAI is disabled, when the config/openai imports
    or the API call fail, or when the model returns no text.

    Args:
        chunks: Chunks to summarize (only the first 15 are sent, each cut to
            500 characters)
        query: Original query for context
        max_tokens: Max tokens for summary
        model: Chat-completions model name

    Returns:
        LLM-generated summary, extractive fallback summary, or '' when there
        are no chunks.
    """
    if not chunks:
        # Nothing to summarize; avoid a pointless API round-trip.
        return ""

    try:
        from config import OPENAI_ENABLED, OPENAI_API_KEY
        if not OPENAI_ENABLED:
            return summarize_chunks(chunks)

        from openai import OpenAI
        client = OpenAI(api_key=OPENAI_API_KEY)

        # Build context
        chunk_texts = []
        for i, chunk in enumerate(chunks[:15], 1):
            title = chunk.get('title') or 'Unknown'
            text = (chunk.get('chunk_text') or '')[:500]
            chunk_texts.append(f"[{i}] From '{title}':\n{text}")

        combined = "\n\n".join(chunk_texts)

        prompt = f"""Summarize the key points from these additional research sources.
Focus on information relevant to: {query or 'the research topic'}

Sources:
{combined}

Provide a concise summary (2-3 sentences) of the main points."""

        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a research assistant summarizing sources."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=max_tokens,
            temperature=0.3
        )
        content = response.choices[0].message.content

    except Exception as e:
        # Best effort: any config/import/API failure degrades to extractive.
        logger.warning(f"LLM summarization failed: {e}")
        return summarize_chunks(chunks)

    # The API may return None (or only whitespace) as message content.
    summary = content.strip() if content else ""
    if not summary:
        logger.warning("LLM summarization returned no content; using extractive summary")
        return summarize_chunks(chunks)
    return summary


# =============================================================================
# SMART CONTEXT SELECTOR
# =============================================================================

@dataclass
class ContextSelectionResult:
    """Outcome of one context selection: the chosen chunks plus token accounting."""

    # Chunks chosen for the context (may end with a synthetic summary chunk).
    selected_chunks: List[Dict[str, Any]]
    # Input chunks that were not selected.
    overflow_chunks: List[Dict[str, Any]]
    # Summary text produced for the overflow chunks, if any.
    overflow_summary: Optional[str]
    # Sum of pinned, regular and summary token estimates.
    total_tokens: int
    # Token budget the selection was run against.
    budget_tokens: int
    # Estimated tokens taken by chunks from pinned documents.
    pinned_tokens: int
    # Estimated tokens taken by non-pinned chunks.
    regular_tokens: int
    # Estimated tokens attributed to the overflow summary.
    summary_tokens: int
    # Free-form diagnostics; every instance gets its own fresh dict.
    stats: Dict[str, Any] = field(default_factory=dict)


class SmartContextSelector:
    """
    Intelligently select chunks for synthesis within a token budget.

    Prioritizes:
    1. Chunks from pinned documents, highest rerank score first, capped at
       ``pinned_budget_percent`` of the budget (60% by default).
    2. Remaining chunks, highest rerank score first, using whatever budget
       the pinned pass left unused.
    3. Diversity: at most ``max_chunks_per_doc`` chunks per document across
       both passes.
    4. An optional summary of overflow chunks, appended as a synthetic chunk
       only when it fits in the leftover budget.

    Token counts come from ``estimate_tokens`` and are therefore estimates.
    """

    # Leftover budget (estimated tokens) that must remain before an overflow
    # summary is attempted at all.
    _MIN_SUMMARY_TOKENS = 100
    # Cap on the requested LLM summary length (same value as the
    # llm_summarize_chunks default).
    _LLM_SUMMARY_MAX_TOKENS = 300

    def __init__(
        self,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        pinned_budget_percent: float = PINNED_BUDGET_PERCENT,
        max_chunks_per_doc: int = DIVERSITY_MAX_CHUNKS_PER_DOC,
        summarize_overflow: bool = True,
        use_llm_summary: bool = True
    ):
        """
        Initialize context selector.

        Args:
            max_tokens: Maximum (estimated) tokens for the whole context
            pinned_budget_percent: Fraction (0.0-1.0) of max_tokens that
                chunks from pinned documents may use
            max_chunks_per_doc: Max chunks from a single document (diversity)
            summarize_overflow: Whether to summarize overflow chunks
            use_llm_summary: Use LLM for summarization (otherwise extractive)

        Raises:
            ValueError: If max_tokens is negative or pinned_budget_percent is
                outside [0, 1] (a larger fraction would let pinned chunks
                exceed the total budget).
        """
        if max_tokens < 0:
            raise ValueError(f"max_tokens must be >= 0, got {max_tokens}")
        if not 0.0 <= pinned_budget_percent <= 1.0:
            raise ValueError(
                f"pinned_budget_percent must be within [0, 1], got {pinned_budget_percent}"
            )
        self.max_tokens = max_tokens
        self.pinned_budget_percent = pinned_budget_percent
        self.max_chunks_per_doc = max_chunks_per_doc
        self.summarize_overflow = summarize_overflow
        self.use_llm_summary = use_llm_summary

    def select_context(
        self,
        chunks: List[Dict[str, Any]],
        pinned_doc_ids: List[str] = None,
        rerank_scores: Dict[str, float] = None,
        query: str = None
    ) -> ContextSelectionResult:
        """
        Select optimal chunks for synthesis.

        Args:
            chunks: List of chunk dicts with 'chunk_id', 'chunk_text', 'document_id'
            pinned_doc_ids: Document IDs to prioritize
            rerank_scores: Dict mapping chunk_id to rerank score (missing -> 0)
            query: Original query (for summarization context)

        Returns:
            ContextSelectionResult with selected chunks and metadata.
            ``overflow_chunks`` keeps the input order. ``total_tokens`` never
            exceeds ``max_tokens``, because the overflow summary is dropped
            when it does not fit.
        """
        pinned_ids = set(pinned_doc_ids or [])
        scores = rerank_scores or {}

        def get_score(chunk: Dict[str, Any]) -> float:
            # Chunks without a rerank score are treated as scoring 0.
            return scores.get(chunk.get('chunk_id'), 0)

        # Separate pinned vs regular chunks, best score first. list.sort is
        # stable (also with reverse=True), so ties keep their input order.
        pinned_chunks = [c for c in chunks if c.get('document_id') in pinned_ids]
        regular_chunks = [c for c in chunks if c.get('document_id') not in pinned_ids]
        pinned_chunks.sort(key=get_score, reverse=True)
        regular_chunks.sort(key=get_score, reverse=True)

        # Shared across both passes so the diversity cap is global per document.
        docs_seen: Dict[Any, int] = {}

        pinned_budget = int(self.max_tokens * self.pinned_budget_percent)
        selected, pinned_tokens = self._take_within_budget(
            pinned_chunks, pinned_budget, docs_seen
        )
        # Regular chunks may use everything the pinned pass did not.
        regular_selected, regular_tokens = self._take_within_budget(
            regular_chunks, self.max_tokens - pinned_tokens, docs_seen
        )
        selected.extend(regular_selected)

        # Identify overflow by object identity: chunk_id may be missing or
        # duplicated, and an id-based lookup would then hide unselected chunks.
        selected_refs = {id(c) for c in selected}
        overflow = [c for c in chunks if id(c) not in selected_refs]

        # Summarize overflow if enabled and there is room for it
        overflow_summary = None
        summary_tokens = 0
        remaining = self.max_tokens - pinned_tokens - regular_tokens

        if self.summarize_overflow and overflow and remaining > self._MIN_SUMMARY_TOKENS:
            overflow_summary, summary_chunk, summary_tokens = self._summarize_overflow(
                overflow, get_score, query, remaining
            )
            if summary_chunk is not None:
                selected.append(summary_chunk)

        # Build result
        total_tokens = pinned_tokens + regular_tokens + summary_tokens
        # Guard against a zero budget.
        budget_used_percent = (
            round(total_tokens / self.max_tokens * 100, 1) if self.max_tokens else 0.0
        )

        return ContextSelectionResult(
            selected_chunks=selected,
            overflow_chunks=overflow,
            overflow_summary=overflow_summary,
            total_tokens=total_tokens,
            budget_tokens=self.max_tokens,
            pinned_tokens=pinned_tokens,
            regular_tokens=regular_tokens,
            summary_tokens=summary_tokens,
            stats={
                'total_input_chunks': len(chunks),
                'selected_chunks': len(selected),
                'overflow_chunks': len(overflow),
                'pinned_doc_ids': list(pinned_ids),
                'unique_docs': len(docs_seen),
                'budget_used_percent': budget_used_percent
            }
        )

    def _take_within_budget(
        self,
        candidates: List[Dict[str, Any]],
        budget: int,
        docs_seen: Dict[Any, int]
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Greedily take candidates, in order, that fit the budget and diversity cap.

        Updates ``docs_seen`` (document_id -> chunks taken) in place.

        Returns:
            (taken_chunks, estimated_tokens_used)
        """
        taken: List[Dict[str, Any]] = []
        used = 0
        for chunk in candidates:
            doc_id = chunk.get('document_id')

            # Check diversity limit
            if docs_seen.get(doc_id, 0) >= self.max_chunks_per_doc:
                continue

            # Check budget; skip rather than stop, since a later, smaller
            # chunk may still fit.
            chunk_tokens = estimate_tokens(chunk.get('chunk_text') or '')
            if used + chunk_tokens > budget:
                continue

            taken.append(chunk)
            used += chunk_tokens
            docs_seen[doc_id] = docs_seen.get(doc_id, 0) + 1
        return taken, used

    def _summarize_overflow(
        self,
        overflow: List[Dict[str, Any]],
        score_fn,
        query: Optional[str],
        remaining: int
    ) -> Tuple[Optional[str], Optional[Dict[str, Any]], int]:
        """Summarize overflow chunks into a synthetic chunk that fits ``remaining``.

        Returns:
            (summary_text, summary_chunk_or_None, summary_tokens). When the
            summary is empty it is returned as-is with no chunk and 0 tokens.
            When it does not fit the remaining budget it is dropped entirely
            (None, None, 0).
        """
        # The summarizers only read the first 10-15 chunks, so hand them the
        # best-scored overflow chunks first.
        ranked = sorted(overflow, key=score_fn, reverse=True)

        if self.use_llm_summary:
            summary = llm_summarize_chunks(
                ranked, query,
                max_tokens=min(remaining, self._LLM_SUMMARY_MAX_TOKENS)
            )
        else:
            summary = summarize_chunks(ranked)

        if not summary:
            return summary, None, 0

        summary_chunk = {
            'chunk_id': 'OVERFLOW_SUMMARY',
            'chunk_text': f"[Summary of {len(overflow)} additional sources]: {summary}",
            'document_id': 'SUMMARY',
            'title': 'Additional Sources Summary',
            'is_summary': True
        }
        # Count the full synthetic text (prefix included) against the budget.
        summary_tokens = estimate_tokens(summary_chunk['chunk_text'])
        if summary_tokens > remaining:
            logger.info(
                "Overflow summary (%d tokens) exceeds remaining budget (%d); omitting it",
                summary_tokens, remaining
            )
            return None, None, 0
        return summary, summary_chunk, summary_tokens

    def format_context_for_llm(
        self,
        result: ContextSelectionResult,
        include_metadata: bool = True
    ) -> str:
        """
        Format selected chunks as context string for LLM.

        Args:
            result: ContextSelectionResult from select_context
            include_metadata: Include a "[Source i: title by author (doc_id)]"
                header before each regular chunk

        Returns:
            Formatted context string
        """
        lines: List[str] = []

        for i, chunk in enumerate(result.selected_chunks, 1):
            # `or` also covers keys that are present but None.
            text = chunk.get('chunk_text') or ''

            if chunk.get('is_summary'):
                lines.append(f"\n{text}\n")
                continue

            if include_metadata:
                title = chunk.get('title') or 'Unknown'
                author = chunk.get('author') or ''
                doc_id = chunk.get('document_id') or ''

                header = f"[Source {i}: {title}"
                if author:
                    header += f" by {author}"
                header += f" ({doc_id})]"
                lines.append(header)

            lines.append(text)
            lines.append("")

        return "\n".join(lines)


# =============================================================================
# INTEGRATION HELPERS
# =============================================================================

def select_context_for_synthesis(
    chunks: List[Dict[str, Any]],
    max_tokens: int = 8000,
    pinned_doc_ids: List[str] = None,
    rerank_scores: Dict[str, float] = None,
    query: str = None,
    *,
    summarize_overflow: bool = True,
    use_llm_summary: bool = True
) -> Tuple[List[Dict[str, Any]], str]:
    """
    Convenience function: select context and format it in one call.

    Args:
        chunks: Search result chunks
        max_tokens: Token budget (8000, the same value as DEFAULT_MAX_TOKENS)
        pinned_doc_ids: Documents to prioritize
        rerank_scores: Reranking scores by chunk_id
        query: Original query
        summarize_overflow: Whether to summarize chunks that did not fit
            (keyword-only)
        use_llm_summary: Use the LLM summarizer rather than the extractive
            one; pass False to avoid any API call (keyword-only)

    Returns:
        Tuple of (selected_chunks, formatted_context_string)
    """
    selector = SmartContextSelector(
        max_tokens=max_tokens,
        summarize_overflow=summarize_overflow,
        use_llm_summary=use_llm_summary
    )
    result = selector.select_context(
        chunks=chunks,
        pinned_doc_ids=pinned_doc_ids,
        rerank_scores=rerank_scores,
        query=query
    )

    context_str = selector.format_context_for_llm(result)

    return result.selected_chunks, context_str


# NOTE: Tuple is already imported at the top of the file; this re-import is redundant but harmless.
from typing import Tuple


# =============================================================================
# CLI FOR TESTING
# =============================================================================

def main():
    """Run SmartContextSelector on synthetic chunks and print a report.

    Builds ``--chunks`` fake chunks spread round-robin over five documents,
    assigns linearly decreasing rerank scores, pins DOC_001 and DOC_002, and
    runs selection with the extractive (non-LLM) overflow summary.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Test Smart Context Selector')
    cli.add_argument('--max-tokens', type=int, default=2000,
                     help='Token budget for test')
    cli.add_argument('--chunks', type=int, default=20,
                     help='Number of test chunks to generate')
    opts = cli.parse_args()

    def make_chunk(idx):
        doc_num = (idx % 5) + 1  # five documents, assigned round-robin
        doc_id = f"DOC_{doc_num:03d}"
        return {
            'chunk_id': f"CHUNK_{idx:03d}",
            'chunk_text': f"This is test chunk {idx} from document {doc_id}. " * 20,
            'document_id': doc_id,
            'title': f"Test Document {doc_num}",
            'author': f"Author {doc_num}"
        }

    test_chunks = [make_chunk(idx) for idx in range(opts.chunks)]

    # Earlier chunks score higher: 1.0, 0.96, 0.92, ...
    rerank_scores = {
        chunk['chunk_id']: 1.0 - (idx * 0.04)
        for idx, chunk in enumerate(test_chunks)
    }

    pinned = ['DOC_001', 'DOC_002']

    selector = SmartContextSelector(
        max_tokens=opts.max_tokens,
        summarize_overflow=True,
        use_llm_summary=False  # extractive summary: no API calls in a test run
    )
    result = selector.select_context(
        chunks=test_chunks,
        pinned_doc_ids=pinned,
        rerank_scores=rerank_scores,
        query="Test query"
    )

    rule = '=' * 60

    print(f"\n{rule}")
    print("SMART CONTEXT SELECTION TEST")
    print(rule)
    print(f"\nInput: {len(test_chunks)} chunks")
    print(f"Budget: {opts.max_tokens} tokens")
    print(f"Pinned docs: {pinned}")

    print(f"\n{rule}")
    print("RESULTS")
    print(rule)
    print(f"Selected: {len(result.selected_chunks)} chunks")
    print(f"Overflow: {len(result.overflow_chunks)} chunks")
    print(f"Tokens used: {result.total_tokens} / {result.budget_tokens}")
    breakdown = (
        ("Pinned", result.pinned_tokens),
        ("Regular", result.regular_tokens),
        ("Summary", result.summary_tokens),
    )
    for label, value in breakdown:
        print(f"  - {label}: {value}")

    print(f"\nStats: {result.stats}")

    if result.overflow_summary:
        print(f"\nOverflow Summary:\n{result.overflow_summary[:200]}...")


# Run the synthetic selection demo when this file is executed as a script.
if __name__ == '__main__':
    main()
