#!/usr/bin/env python3
"""
Autonomous Research Agent for Research Development Framework v2.1

This agent implements an iterative research workflow that goes beyond
single-pass search → save → log. It can:

1. Plan: Break down complex research questions into sub-queries
2. Search: Execute searches using hybrid search with re-ranking
3. Analyze: Evaluate if gathered information is sufficient
4. Iterate: Self-correct and search for gaps
5. Synthesize: Combine results into a coherent research report

The agent uses LLM reasoning to determine when it has enough information
and what additional searches might be needed.

Usage:
    # Single question
    python research_agent.py "Compare Steiner and Jung on dreams"

    # With specific parameters
    python research_agent.py "What is the etheric body?" --max-iterations 5 --output report.md

    # Batch mode from file
    python research_agent.py --batch questions.txt --output-dir reports/

Requirements:
    - OpenAI API key for LLM reasoning (or local LLM via Ollama)
    - sentence-transformers for re-ranking (optional but recommended)

Example Agent Flow:
    1. User: "Compare Steiner and Jung on dreams"
    2. Agent plans: ["Steiner views on dreams", "Jung views on dreams", "Compare both"]
    3. Agent searches "Steiner dreams" → analyzes → finds 5 relevant chunks
    4. Agent searches "Jung dreams" → analyzes → finds 3 relevant chunks
    5. Agent notices gap: "Need more on Jung's archetypes in dreams"
    6. Agent searches "Jung archetypes dreams" → finds 4 more chunks
    7. Agent synthesizes all findings into structured report
"""

import os
import sys
import json
import argparse
import logging
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field

# Add parent directory for imports
sys.path.insert(0, str(Path(__file__).parent))

from config import (
    OPENAI_ENABLED, OPENAI_API_KEY, INTELLIGENCE_MODE,
    CLOUD_MODELS, LOCAL_LLM_ENDPOINT, LOCAL_LLM_MODEL
)
from db_utils import (
    hybrid_search_with_rerank,
    graphrag_search,
    get_chunk_with_context,
    log_search
)

# Configure root logging at import time: INFO level, timestamped records
# tagged with logger name and severity.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class SearchResult:
    """One retrieved text chunk, where it came from, and how it scored."""
    # Identity of the chunk and of the document that contains it
    chunk_id: str
    document_id: str
    title: str
    # Passage text as returned by the search backend
    chunk_text: str
    # Relevance score from the retriever (scale depends on the retrieval method)
    score: float
    # Label of the retrieval path, e.g. 'hybrid' or 'graphrag'
    retrieval_method: str
    # Concept names attached to the chunk; each instance gets its own list
    concepts: List[str] = field(default_factory=list)


@dataclass
class ResearchStep:
    """A single search performed during a research session and what it returned."""
    query: str  # the search string that was executed
    reasoning: str  # short explanation of why this search was run
    results: List[SearchResult]  # hits returned for `query`
    timestamp: str  # ISO-8601 time the step was recorded
    iteration: int  # 1-based position of this step in the search loop


@dataclass
class IdentifiedGap:
    """A missing piece of information flagged by the sufficiency analysis."""
    description: str  # what information is missing
    suggested_query: str  # search string proposed to fill the gap
    identified_at: str  # ISO-8601 time the gap was recorded
    searched: bool = False  # whether a search for this gap has been run
    search_results_count: int = 0  # number of chunks that search returned

    def to_dict(self) -> Dict[str, Any]:
        """Return the gap as a plain ``field name -> value`` dict for serialization."""
        exported = (
            'description',
            'suggested_query',
            'identified_at',
            'searched',
            'search_results_count',
        )
        return {name: getattr(self, name) for name in exported}


@dataclass
class ResearchSession:
    """Everything produced by one research run: plan, search steps, synthesis, stats and cost."""
    # The question as posed and the sub-queries planned for it
    original_question: str
    sub_queries: List[str]
    # Chronological search steps and the final written synthesis
    steps: List[ResearchStep]
    synthesis: str
    # Aggregate counts over all gathered chunks
    total_chunks: int
    unique_documents: int
    # ISO-8601 timestamps bracketing the session
    started_at: str
    completed_at: str
    iterations: int
    # Gaps reported by the sufficiency analysis
    identified_gaps: List[IdentifiedGap] = field(default_factory=list)
    # True: gaps were searched automatically; False: collected for user review
    auto_fill_gaps: bool = False
    # API usage / estimated cost figures from the cost tracker
    cost_summary: Dict[str, Any] = field(default_factory=dict)
    # Set when the session ended early because the cost budget ran out
    budget_exceeded: bool = False


# =============================================================================
# COST TRACKING
# =============================================================================

# Token pricing (USD per 1M tokens) - as of 2024.
# CostTracker resolves a model by exact name first, then by the longest listed
# name that is a prefix of it (so a dated snapshot such as 'gpt-4o-2024-08-06'
# is priced like 'gpt-4o'), and only then falls back to the zero-cost 'default'.
MODEL_PRICING = {
    # GPT-4o
    'gpt-4o': {'input': 2.50, 'output': 10.00},
    'gpt-4o-2024-11-20': {'input': 2.50, 'output': 10.00},
    # GPT-4o-mini
    'gpt-4o-mini': {'input': 0.15, 'output': 0.60},
    'gpt-4o-mini-2024-07-18': {'input': 0.15, 'output': 0.60},
    # GPT-4 Turbo
    'gpt-4-turbo': {'input': 10.00, 'output': 30.00},
    'gpt-4-turbo-preview': {'input': 10.00, 'output': 30.00},
    # GPT-3.5
    'gpt-3.5-turbo': {'input': 0.50, 'output': 1.50},
    'gpt-3.5-turbo-0125': {'input': 0.50, 'output': 1.50},
    # Local models (no cost)
    'default': {'input': 0.0, 'output': 0.0},
}


@dataclass
class CostTracker:
    """
    Tracks API usage and estimated costs.

    Provides safety rails to prevent runaway costs during research sessions.
    Costs are estimates derived from MODEL_PRICING. The budget check is meant
    to be called before each request, so the final spend can overshoot the
    limit by at most the cost of one request.
    """
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_requests: int = 0
    estimated_cost_usd: float = 0.0
    budget_limit_usd: float = 1.0  # Default $1 limit
    model_name: str = 'gpt-4o-mini'

    def _pricing(self) -> Dict[str, float]:
        """
        Return the per-1M-token pricing entry for ``self.model_name``.

        Exact names win. Otherwise the longest known name that prefixes the
        model name is used, so unlisted snapshots of a priced model are not
        silently treated as free (which would disable the budget limit). Names
        with no match at all use the zero-cost 'default' entry.
        """
        name = self.model_name or ''
        if name in MODEL_PRICING:
            return MODEL_PRICING[name]
        candidates = [
            key for key in MODEL_PRICING
            if key != 'default' and name.startswith(key)
        ]
        if candidates:
            return MODEL_PRICING[max(candidates, key=len)]
        return MODEL_PRICING['default']

    def add_usage(self, input_tokens: Optional[int], output_tokens: Optional[int]) -> None:
        """
        Record token usage from one API call and add its estimated cost.

        ``None`` counts are treated as 0 so that a usage object with missing
        fields cannot raise here. A TypeError at this point would be caught by
        the caller's generic error handling, and an otherwise successful
        response would be discarded.
        """
        input_tokens = input_tokens or 0
        output_tokens = output_tokens or 0

        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.total_requests += 1

        pricing = self._pricing()
        input_cost = (input_tokens / 1_000_000) * pricing['input']
        output_cost = (output_tokens / 1_000_000) * pricing['output']
        self.estimated_cost_usd += input_cost + output_cost

    def check_budget(self) -> Tuple[bool, str]:
        """
        Check if we're within budget.

        Returns:
            Tuple of (within_budget, warning_message). The message is empty
            below 80% of the budget, a warning from 80% on, and an error
            description once the limit is reached (within_budget=False).
        """
        if self.estimated_cost_usd >= self.budget_limit_usd:
            return False, f"Budget limit reached: ${self.estimated_cost_usd:.4f} >= ${self.budget_limit_usd:.2f}"

        # Warning at 80% of budget
        if self.estimated_cost_usd >= self.budget_limit_usd * 0.8:
            return True, f"Warning: Approaching budget limit (${self.estimated_cost_usd:.4f}/${self.budget_limit_usd:.2f})"

        return True, ""

    def get_summary(self) -> Dict[str, Any]:
        """Get usage summary (token counts, cost, remaining budget, model)."""
        return {
            'total_requests': self.total_requests,
            'total_input_tokens': self.total_input_tokens,
            'total_output_tokens': self.total_output_tokens,
            'total_tokens': self.total_input_tokens + self.total_output_tokens,
            'estimated_cost_usd': round(self.estimated_cost_usd, 6),
            'budget_limit_usd': self.budget_limit_usd,
            'budget_remaining_usd': round(self.budget_limit_usd - self.estimated_cost_usd, 6),
            # Guard against division by zero for a zero budget
            'budget_used_percent': round((self.estimated_cost_usd / self.budget_limit_usd) * 100, 1) if self.budget_limit_usd > 0 else 0,
            'model': self.model_name,
        }

    def format_status(self) -> str:
        """Format a one-line status (cost, percentage, tokens, requests) for display."""
        summary = self.get_summary()
        return (
            f"Cost: ${summary['estimated_cost_usd']:.4f} / ${summary['budget_limit_usd']:.2f} "
            f"({summary['budget_used_percent']}%) | "
            f"Tokens: {summary['total_tokens']:,} | "
            f"Requests: {summary['total_requests']}"
        )


class BudgetExceededError(Exception):
    """Signal that the configured API cost budget has been used up.

    LLMInterface.chat raises it instead of sending a request once the
    estimated spend has reached the budget limit.
    """


# =============================================================================
# LLM INTERFACE
# =============================================================================

class LLMInterface:
    """
    Interface for LLM calls - supports OpenAI and local models with cost tracking.

    The backend is chosen from INTELLIGENCE_MODE:
    - 'cloud' (and OPENAI_ENABLED): OpenAI API with CLOUD_MODELS['chat'].
    - 'local': an OpenAI-compatible endpoint at LOCAL_LLM_ENDPOINT.
    - anything else: no client. Every chat() call is answered by a keyword
      heuristic (see _fallback_response).
    """

    def __init__(self, budget_limit_usd: float = 1.0):
        """
        Initialize LLM interface with cost tracking.

        Args:
            budget_limit_usd: Maximum budget in USD (default: $1.00)
        """
        self.mode = INTELLIGENCE_MODE
        self.cost_tracker = CostTracker(budget_limit_usd=budget_limit_usd)

        if self.mode == 'cloud' and OPENAI_ENABLED:
            from openai import OpenAI
            self.client = OpenAI(api_key=OPENAI_API_KEY)
            self.model = CLOUD_MODELS['chat']
            self.cost_tracker.model_name = self.model
            logger.info(f"Using OpenAI API with model: {self.model}")
            logger.info(f"Budget limit: ${budget_limit_usd:.2f}")

        elif self.mode == 'local':
            from openai import OpenAI
            self.client = OpenAI(
                base_url=LOCAL_LLM_ENDPOINT,
                api_key="not-needed"
            )
            self.model = LOCAL_LLM_MODEL
            self.cost_tracker.model_name = self.model
            logger.info(f"Using local LLM at {LOCAL_LLM_ENDPOINT} with model: {self.model}")
            logger.info("Local mode: No API costs")

        else:
            self.client = None
            self.model = None
            logger.warning("No LLM available. Agent will use basic heuristics.")

    def chat(self, messages: List[Dict], temperature: float = 0.3) -> str:
        """
        Send a chat completion request with cost tracking.

        Without a client, or when the request fails for any reason other than
        the budget, the keyword-heuristic fallback answer is returned instead.

        Raises:
            BudgetExceededError: If the budget limit is reached
        """
        if self.client is None:
            return self._fallback_response(messages)

        # Check budget before making call
        within_budget, warning = self.cost_tracker.check_budget()
        if not within_budget:
            raise BudgetExceededError(warning)
        if warning:
            logger.warning(warning)

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=2000
            )

            # Track usage if available
            if hasattr(response, 'usage') and response.usage:
                self.cost_tracker.add_usage(
                    input_tokens=response.usage.prompt_tokens,
                    output_tokens=response.usage.completion_tokens
                )
                logger.debug(f"API call: {response.usage.prompt_tokens} in, "
                           f"{response.usage.completion_tokens} out | "
                           f"{self.cost_tracker.format_status()}")

            return response.choices[0].message.content.strip()

        except BudgetExceededError:
            raise
        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            return self._fallback_response(messages)

    def _fallback_response(self, messages: List[Dict]) -> str:
        """
        Fallback when no LLM is available (or a call failed).

        Picks a canned reply by keyword. The system message is inspected first
        because it names the task ("planning", "completeness evaluator",
        "synthesizer") and, unlike the user message, does not embed retrieved
        passages. Matching on the user message alone misfired whenever source
        text contained words such as "plan" ("planet") or "enough". The last
        message is only consulted if the system message has no known keyword.
        """
        system_text = ''
        for message in messages:
            if message.get('role') == 'system':
                system_text = message.get('content') or ''
                break
        last_text = (messages[-1].get('content') or '') if messages else ''

        for text in (system_text, last_text):
            # Multi-part (non-string) contents carry no usable keywords here
            if not isinstance(text, str):
                continue
            reply = self._keyword_reply(text.lower())
            if reply is not None:
                return reply
        return 'Continue with standard search.'

    @staticmethod
    def _keyword_reply(text: str) -> Optional[str]:
        """Map lower-cased prompt text to a canned reply, or None if no keyword matches."""
        if 'plan' in text or 'sub-queries' in text:
            return '["search for the main topic"]'
        if 'sufficient' in text or 'enough' in text:
            return 'YES'
        if 'synthesize' in text or 'summarize' in text:
            return 'Based on the gathered research, the topic covers several key areas.'
        return None

    def get_cost_summary(self) -> Dict[str, Any]:
        """Get the current cost tracking summary."""
        return self.cost_tracker.get_summary()


# =============================================================================
# RESEARCH AGENT
# =============================================================================

class ResearchAgent:
    """
    Autonomous agent for iterative research tasks.

    The agent follows a Plan-Search-Analyze-Iterate loop:
    1. Plan: Break the question into searchable sub-queries
    2. Search: Execute each query with hybrid search + re-ranking
    3. Analyze: Evaluate if results are sufficient
    4. Iterate: If gaps found, generate new queries and repeat
    5. Synthesize: Combine all findings into a report

    Per-session state (gathered results, search history, identified gaps) is
    reset at the start of every research() call, so one agent instance can be
    reused for several questions without an earlier question's chunks leaking
    into a later report.
    """

    def __init__(
        self,
        max_iterations: int = 5,
        min_results_per_query: int = 3,
        use_graphrag: bool = True,
        use_rerank: bool = True,
        auto_fill_gaps: bool = False,
        budget_limit_usd: float = 1.0,
        strict_library_only: bool = False
    ):
        """
        Initialize the research agent.

        Args:
            max_iterations: Maximum number of searches per session
            min_results_per_query: Minimum results needed per sub-query
                (stored; not currently consulted by the search loop)
            use_graphrag: Whether to use knowledge graph traversal
            use_rerank: Whether to use cross-encoder re-ranking (hybrid path only)
            auto_fill_gaps: If True, automatically search for gaps; if False, collect gaps for user review
            budget_limit_usd: Maximum budget for OpenAI API calls (default: $1.00)
            strict_library_only: If True, never use web search - admit gaps instead
                (also forces auto_fill_gaps off)
        """
        self.max_iterations = max_iterations
        self.min_results_per_query = min_results_per_query
        self.use_graphrag = use_graphrag
        self.use_rerank = use_rerank
        self.auto_fill_gaps = auto_fill_gaps
        self.budget_limit_usd = budget_limit_usd
        self.strict_library_only = strict_library_only

        # Override auto_fill_gaps if strict_library_only is set
        if strict_library_only:
            self.auto_fill_gaps = False
            logger.info("Strict library-only mode: Web search disabled, gaps will be reported rather than filled")

        self.llm = LLMInterface(budget_limit_usd=budget_limit_usd)
        self.all_results: Dict[str, SearchResult] = {}  # chunk_id -> result
        self.search_history: List[str] = []
        self.identified_gaps: List[IdentifiedGap] = []  # Collected gaps for review

    # ------------------------------------------------------------------
    # Response parsing helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_json_array(text: str) -> Optional[List[Any]]:
        """
        Parse the JSON array embedded in an LLM response.

        Tolerates Markdown code fences and surrounding prose by slicing from
        the first '[' to the last ']'. Returns None when there is no array or
        it is not valid JSON.
        """
        if not text:
            return None
        start = text.find('[')
        end = text.rfind(']') + 1
        if start < 0 or end <= start:
            return None
        try:
            parsed = json.loads(text[start:end])
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, list) else None

    @staticmethod
    def _is_sufficient_reply(text: str) -> bool:
        """
        Decide whether a sufficiency-analysis reply declares the research complete.

        Accepts the requested "SUFFICIENT" (provided no JSON array is present)
        and a leading "YES", which is what LLMInterface's offline heuristic
        answers.
        """
        upper = text.upper()
        words = upper.split()
        if words and words[0].strip('.,:;!*"\'`') == 'YES':
            return True
        # "SUFFICIENT" is a substring of "INSUFFICIENT", so rule out negations first.
        if 'INSUFFICIENT' in upper or 'NOT SUFFICIENT' in upper:
            return False
        return 'SUFFICIENT' in upper and '[' not in text

    @staticmethod
    def _first_score(chunk: Dict[str, Any], *keys: str) -> float:
        """Return the value of the first key in ``keys`` that is present and not None (else 0)."""
        for key in keys:
            value = chunk.get(key)
            if value is not None:
                return value
        return 0

    @staticmethod
    def _interleave_by_query(
        results_by_query: Dict[str, List[SearchResult]]
    ) -> List[SearchResult]:
        """
        Merge per-query result lists round-robin by rank, dropping repeated chunks.

        synthesize_findings() only reads the first 20 results. Concatenating in
        search order lets the first two 10-result queries fill that window, so
        later sub-queries and gap searches would never reach the synthesis.
        Taking every query's rank-1 hit, then every rank-2 hit, and so on keeps
        each query's own ordering while giving all of them a share.
        """
        merged: Dict[str, SearchResult] = {}
        per_query = list(results_by_query.values())
        deepest = max((len(results) for results in per_query), default=0)
        for rank in range(deepest):
            for results in per_query:
                if rank < len(results):
                    merged.setdefault(results[rank].chunk_id, results[rank])
        return list(merged.values())

    # ------------------------------------------------------------------
    # Plan / Search / Analyze / Synthesize
    # ------------------------------------------------------------------

    def plan_research(self, question: str) -> List[str]:
        """
        Break down a research question into sub-queries.

        The LLM is asked for a JSON array of queries. Non-string and blank
        entries are dropped, duplicates are removed (first occurrence kept), and
        at most 5 queries are returned. If nothing usable comes back, the
        original question is searched on its own.

        Args:
            question: The main research question

        Returns:
            Non-empty list of sub-queries to search

        Raises:
            BudgetExceededError: Propagated from the LLM call
        """
        prompt = f"""You are a research planning assistant. Break down this research question into 2-5 specific search queries that together will answer the question comprehensively.

Research Question: {question}

Return ONLY a JSON array of search query strings. Each query should be:
- Specific and searchable
- Focused on one aspect of the question
- Worded to find relevant passages

Example output format:
["query 1", "query 2", "query 3"]

Your response (JSON array only):"""

        messages = [
            {"role": "system", "content": "You are a research planning assistant. Respond only with valid JSON."},
            {"role": "user", "content": prompt}
        ]

        response = self.llm.chat(messages)

        parsed = self._extract_json_array(response)
        if parsed is None:
            logger.warning("Could not parse sub-queries, using original question")
            return [question]

        sub_queries: List[str] = []
        for item in parsed:
            if not isinstance(item, str):
                continue
            query = item.strip()
            if query and query not in sub_queries:
                sub_queries.append(query)

        if not sub_queries:
            logger.warning("Planner returned no usable sub-queries, using original question")
            return [question]

        return sub_queries[:5]  # Limit to 5 queries

    def execute_search(self, query: str) -> List[SearchResult]:
        """
        Execute a search query using hybrid search with optional GraphRAG.

        Search errors are logged and yield whatever results were built before
        the failure (possibly none). The query is always appended to
        ``search_history``.

        Args:
            query: The search query

        Returns:
            List of SearchResult objects
        """
        results: List[SearchResult] = []

        try:
            if self.use_graphrag:
                # Use GraphRAG for multi-hop reasoning
                graphrag_results = graphrag_search(
                    query=query,
                    limit=10,
                    hop_depth=2,
                    include_direct_search=True
                )

                for chunk in graphrag_results.get('results') or []:
                    concepts = chunk.get('concept_names')
                    results.append(SearchResult(
                        chunk_id=chunk.get('chunk_id', ''),
                        document_id=chunk.get('document_id', ''),
                        # `or` also covers keys that are present but None
                        title=chunk.get('title') or 'Unknown',
                        chunk_text=chunk.get('chunk_text') or '',
                        score=self._first_score(chunk, 'rerank_score', 'rrf_score'),
                        retrieval_method=chunk.get('retrieval_method') or 'graphrag',
                        concepts=concepts if isinstance(concepts, list) else []
                    ))
            else:
                # Use hybrid search with re-ranking
                hybrid_results = hybrid_search_with_rerank(
                    query_text=query,
                    limit=10,
                    use_rerank=self.use_rerank
                )

                for chunk in hybrid_results or []:
                    results.append(SearchResult(
                        chunk_id=chunk.get('chunk_id', ''),
                        document_id=chunk.get('document_id', ''),
                        title=chunk.get('title') or 'Unknown',
                        chunk_text=chunk.get('chunk_text') or '',
                        score=self._first_score(chunk, 'rerank_score', 'rrf_score', 'similarity'),
                        retrieval_method='hybrid'
                    ))

        except Exception as e:
            logger.error(f"Search failed for query '{query}': {e}")

        self.search_history.append(query)
        return results

    def analyze_sufficiency(
        self,
        question: str,
        sub_queries: List[str],
        results_by_query: Dict[str, List[SearchResult]]
    ) -> Tuple[bool, List[Dict[str, str]]]:
        """
        Analyze if the gathered results are sufficient to answer the question.

        The reply is interpreted as follows:
        - "SUFFICIENT" (or a leading "YES") means done.
        - A JSON array of {"description", "suggested_query"} objects lists the gaps.
        - Any other text becomes a single gap, with a "NO - " prefix removed.
        A bare "NO" yields (False, []): incomplete, but nothing actionable.

        Args:
            question: Original research question
            sub_queries: List of sub-queries planned (not used in the prompt)
            results_by_query: Results gathered for each query

        Returns:
            Tuple of (is_sufficient, list_of_gaps)
            Each gap is a dict with 'description' and 'suggested_query' keys

        Raises:
            BudgetExceededError: Propagated from the LLM call
        """
        # Build context summary
        context_summary = []
        for query, results in results_by_query.items():
            context_summary.append(f"Query: {query}")
            context_summary.append(f"  Results: {len(results)} chunks found")
            if results:
                # dict.fromkeys de-duplicates in rank order, keeping the prompt deterministic
                titles = dict.fromkeys(str(r.title) for r in results[:5])
                context_summary.append(f"  Sources: {', '.join(titles)}")

        prompt = f"""You are evaluating research completeness.

Original Question: {question}

Research Progress:
{chr(10).join(context_summary)}

Previous Searches: {', '.join(self.search_history[-10:])}

Analyze:
1. Have we gathered enough information to answer the original question?
2. Are there any obvious gaps or missing perspectives?

If the research is SUFFICIENT, respond with exactly:
SUFFICIENT

If there are gaps, respond with a JSON array of gaps. Each gap should have:
- "description": What information is missing
- "suggested_query": A specific search query to fill this gap

Example response for gaps:
[
  {{"description": "Missing information about X", "suggested_query": "search query for X"}},
  {{"description": "Need more on Y perspective", "suggested_query": "search query for Y"}}
]

Your response (either SUFFICIENT or JSON array):"""

        messages = [
            {"role": "system", "content": "You are a research completeness evaluator. Respond only with SUFFICIENT or a valid JSON array of gaps."},
            {"role": "user", "content": prompt}
        ]

        response = self.llm.chat(messages, temperature=0.2)
        text = (response or '').strip()

        if self._is_sufficient_reply(text):
            return True, []

        gaps: List[Dict[str, str]] = []
        parsed = self._extract_json_array(text)
        if parsed is not None:
            for item in parsed:
                if not (isinstance(item, dict) and 'description' in item):
                    continue
                description = str(item.get('description') or '')
                suggested = str(item.get('suggested_query') or description[:100])
                if description or suggested:
                    gaps.append({
                        'description': description,
                        'suggested_query': suggested,
                    })
        else:
            # No parseable JSON: treat the reply text itself as one gap description
            gap_text = text.replace('NO - ', '').replace('NO-', '').strip()
            if gap_text.upper().rstrip('.!') == 'NO':
                return False, []
            if gap_text:
                gaps.append({
                    'description': gap_text,
                    'suggested_query': gap_text.split('\n')[0][:100],
                })

        return len(gaps) == 0, gaps

    def _add_identified_gap(self, description: str, suggested_query: str) -> IdentifiedGap:
        """
        Record a gap for review, merging duplicates.

        Gaps are keyed case-insensitively by suggested query (or by description
        when there is no query). In auto-fill mode the analysis can run several
        times per session and report the same gap again.

        Returns:
            The newly recorded gap, or the existing gap it duplicates.
        """
        key = (suggested_query or description or '').strip().lower()
        for existing in self.identified_gaps:
            if (existing.suggested_query or existing.description or '').strip().lower() == key:
                return existing

        gap = IdentifiedGap(
            description=description,
            suggested_query=suggested_query,
            identified_at=datetime.now().isoformat(),
        )
        self.identified_gaps.append(gap)
        logger.info(f"Gap identified: {description[:80]}...")
        return gap

    def synthesize_findings(
        self,
        question: str,
        all_results: List[SearchResult]
    ) -> str:
        """
        Synthesize all findings into a coherent research report.

        Only the first 20 results are sent to the LLM, each cut to 500
        characters, so callers should pass results in priority order.

        Args:
            question: Original research question
            all_results: All gathered search results

        Returns:
            Synthesized research report

        Raises:
            BudgetExceededError: Propagated from the LLM call
        """
        context_chunks = []
        for i, result in enumerate(all_results[:20], 1):
            text = result.chunk_text or ''
            # Mark truncation only when something was actually cut off
            excerpt = text[:500] + ('...' if len(text) > 500 else '')
            context_chunks.append(f"""
[{i}] Source: {result.title}
{excerpt}
""")

        prompt = f"""You are a research synthesizer. Based on the gathered research materials, write a comprehensive response to the research question.

Research Question: {question}

Gathered Materials:
{''.join(context_chunks)}

Instructions:
1. Synthesize the key findings from across all sources
2. Organize the response logically
3. Note any contradictions or different perspectives
4. Cite sources by number [1], [2], etc.
5. Highlight any remaining gaps or areas for further research

Write a well-structured research summary:"""

        messages = [
            {"role": "system", "content": "You are an expert research synthesizer who creates clear, well-organized summaries."},
            {"role": "user", "content": prompt}
        ]

        return self.llm.chat(messages, temperature=0.4)

    def get_cost_summary(self) -> Dict[str, Any]:
        """Get the current cost tracking summary."""
        return self.llm.get_cost_summary()

    def research(self, question: str) -> ResearchSession:
        """
        Execute a complete research session.

        The session runs in four stages:
        1. Plan sub-queries.
        2. Search them in order, at most ``max_iterations`` searches in total.
        3. Whenever the queue empties, ask the LLM whether the results suffice.
           In auto-fill mode the first suggested gap query is queued; all
           other gaps are recorded for review.
        4. Synthesize a report.
        If the budget runs out, a partial report listing the sources found so
        far is returned and ``budget_exceeded`` is set.

        Args:
            question: The research question to investigate

        Returns:
            ResearchSession with all steps and synthesis
        """
        started_at = datetime.now().isoformat()
        steps: List[ResearchStep] = []
        results_by_query: Dict[str, List[SearchResult]] = {}

        # Reset every piece of per-session state, not just the gaps, so a reused
        # agent does not mix chunks or search history from earlier questions.
        self.all_results = {}
        self.search_history = []
        self.identified_gaps = []

        # Gap queries queued for auto-fill -> their gap record, so the search
        # outcome can be written back (searched / search_results_count).
        pending_gap_searches: Dict[str, IdentifiedGap] = {}

        budget_exceeded = False
        synthesis = ""
        sub_queries: List[str] = []
        # Bound before the try: a budget stop during planning must not leave it undefined.
        iteration = 0

        logger.info(f"Starting research: {question}")
        logger.info(f"Auto-fill gaps: {self.auto_fill_gaps}")
        logger.info(f"Budget limit: ${self.budget_limit_usd:.2f}")

        try:
            # Step 1: Plan sub-queries
            sub_queries = self.plan_research(question)
            logger.info(f"Planned {len(sub_queries)} sub-queries: {sub_queries}")

            # Step 2-4: Search, Analyze, Iterate
            queries_to_search = sub_queries.copy()

            while iteration < self.max_iterations and queries_to_search:
                iteration += 1
                current_query = queries_to_search.pop(0)

                logger.info(f"Iteration {iteration}: Searching '{current_query}'")

                results = self.execute_search(current_query)

                for result in results:
                    self.all_results[result.chunk_id] = result
                results_by_query[current_query] = results

                gap_record = pending_gap_searches.pop(current_query, None)
                if gap_record is not None:
                    gap_record.searched = True
                    gap_record.search_results_count = len(results)

                steps.append(ResearchStep(
                    query=current_query,
                    reasoning=f"Searching for information about: {current_query}",
                    results=results,
                    timestamp=datetime.now().isoformat(),
                    iteration=iteration
                ))

                logger.info(f"  Found {len(results)} results")

                # Analyze sufficiency only once all queued queries are done
                if queries_to_search:
                    continue

                is_sufficient, gaps = self.analyze_sufficiency(
                    question, sub_queries, results_by_query
                )

                if is_sufficient:
                    logger.info("  Research deemed sufficient")
                elif not gaps:
                    logger.info("  Research judged incomplete, but no actionable gaps were suggested")
                elif self.auto_fill_gaps and iteration < self.max_iterations:
                    # Auto-fill mode: queue the first gap, record the rest for review
                    first_gap, *other_gaps = gaps
                    new_query = first_gap['suggested_query']
                    if new_query and new_query not in self.search_history:
                        queries_to_search.append(new_query)
                        pending_gap_searches[new_query] = self._add_identified_gap(
                            first_gap['description'], new_query
                        )
                        logger.info(f"  Auto-filling gap: {first_gap['description'][:80]}...")
                    else:
                        # Already searched (or no query): keep it for review instead of dropping it
                        self._add_identified_gap(first_gap['description'], new_query)

                    for gap in other_gaps:
                        self._add_identified_gap(gap['description'], gap['suggested_query'])
                else:
                    # Collect mode: store all gaps for user review
                    for gap in gaps:
                        self._add_identified_gap(gap['description'], gap['suggested_query'])
                    logger.info(f"  Collected {len(gaps)} gaps for review")

            # Step 5: Synthesize (interleaved so every query reaches the 20-chunk window)
            synthesis = self.synthesize_findings(
                question, self._interleave_by_query(results_by_query)
            )

        except BudgetExceededError as e:
            budget_exceeded = True
            logger.warning(f"Budget exceeded: {e}")
            logger.warning("Research stopped early due to budget limit")

            # Partial report: list what was found so far
            gathered = self._interleave_by_query(results_by_query)
            if gathered:
                titles = list(dict.fromkeys(str(r.title) for r in gathered))
                source_list = '\n'.join(f"- {title}" for title in titles)
                synthesis = (
                    "**Note: Research was stopped early due to budget limit.**\n\n"
                    f"Based on the {len(gathered)} chunks gathered before the limit was reached, "
                    "these sources were found:\n\n"
                    f"{source_list}\n\n"
                    "Please increase the budget (--budget) to complete the research."
                )
            else:
                synthesis = "Research could not be completed: Budget exceeded before gathering results."

        # Create session
        all_results_list = list(self.all_results.values())
        unique_docs = set(r.document_id for r in all_results_list)

        # Get final cost summary
        cost_summary = self.get_cost_summary()

        session = ResearchSession(
            original_question=question,
            sub_queries=sub_queries,
            steps=steps,
            synthesis=synthesis,
            total_chunks=len(all_results_list),
            unique_documents=len(unique_docs),
            started_at=started_at,
            completed_at=datetime.now().isoformat(),
            iterations=iteration,
            identified_gaps=self.identified_gaps.copy(),
            auto_fill_gaps=self.auto_fill_gaps,
            cost_summary=cost_summary,
            budget_exceeded=budget_exceeded,
        )

        # Log completion status (gaps already searched via auto-fill are not "for review")
        open_gaps = [gap for gap in self.identified_gaps if not gap.searched]
        if budget_exceeded:
            logger.warning(f"Research stopped: Budget exceeded (${cost_summary.get('estimated_cost_usd', 0):.4f})")
        elif open_gaps:
            logger.info(f"Research complete with {len(open_gaps)} gaps identified for review")
        else:
            logger.info(f"Research complete: {len(all_results_list)} chunks from {len(unique_docs)} documents")

        # Log cost summary
        logger.info(f"Cost summary: {cost_summary.get('total_requests', 0)} API calls, "
                   f"${cost_summary.get('estimated_cost_usd', 0):.4f} / "
                   f"${cost_summary.get('budget_limit_usd', 0):.2f}")

        return session


# =============================================================================
# OUTPUT FORMATTING
# =============================================================================

def format_markdown_report(session: ResearchSession) -> str:
    """
    Format research session as a markdown report.

    Sections, in order:
    - A header: question, date, source counts, and optionally a budget
      warning, the API cost line and the gap count.
    - The synthesis.
    - The identified gaps (✓ = searched, ○ = still open).
    - The research process: planned sub-queries and every search step.
    - The unique source documents, sorted by title.

    Args:
        session: A completed (or budget-truncated) research session

    Returns:
        The report as one newline-joined string
    """
    lines = [
        "# Research Report",
        "",
        f"**Question:** {session.original_question}",
        "",
        # completed_at is an ISO timestamp; keep only the YYYY-MM-DD part
        f"**Date:** {session.completed_at[:10]}",
        "",
        f"**Sources:** {session.unique_documents} documents, {session.total_chunks} text chunks",
        "",
    ]

    # Add budget warning if exceeded
    if session.budget_exceeded:
        lines.extend([
            "**⚠️ Status:** Research stopped early - budget limit reached",
            "",
        ])

    # Add cost summary if available
    if session.cost_summary:
        cost = session.cost_summary.get('estimated_cost_usd', 0)
        budget = session.cost_summary.get('budget_limit_usd', 0)
        tokens = session.cost_summary.get('total_tokens', 0)
        lines.extend([
            f"**API Cost:** ${cost:.4f} / ${budget:.2f} ({tokens:,} tokens)",
            "",
        ])

    # Add gaps summary if any
    if session.identified_gaps:
        lines.extend([
            f"**Research Gaps:** {len(session.identified_gaps)} identified",
            "",
        ])

    lines.extend([
        "---",
        "",
        "## Summary",
        "",
        session.synthesis or "",
        "",
        "---",
        "",
    ])

    # Add identified gaps section
    if session.identified_gaps:
        lines.extend([
            "## Research Gaps Identified",
            "",
            "The following gaps were identified during research analysis. Consider searching for additional sources to fill these gaps:",
            "",
        ])
        for i, gap in enumerate(session.identified_gaps, 1):
            searched_marker = "✓" if gap.searched else "○"
            lines.append(f"{i}. {searched_marker} **{gap.description}**")
            lines.append(f"   - Suggested query: `{gap.suggested_query}`")
            if gap.searched:
                lines.append(f"   - Searched: {gap.search_results_count} chunks found")
            lines.append("")
        lines.extend([
            "---",
            "",
        ])

    lines.extend([
        "## Research Process",
        "",
        "### Sub-queries Explored",
        "",
    ])

    for i, query in enumerate(session.sub_queries, 1):
        lines.append(f"{i}. {query}")

    lines.extend([
        "",
        "### Search Steps",
        "",
    ])

    for step in session.steps:
        lines.extend([
            f"**Iteration {step.iteration}:** `{step.query}`",
            f"- Results: {len(step.results)} chunks",
            f"- Time: {step.timestamp}",
            "",
        ])

    lines.extend([
        "---",
        "",
        "## Sources Referenced",
        "",
    ])

    # First title seen per document wins. Titles are normalised to str so a
    # missing/None title cannot make the sort below raise TypeError.
    sources: Dict[str, str] = {}
    for step in session.steps:
        for result in step.results:
            sources.setdefault(
                result.document_id,
                str(result.title) if result.title else 'Unknown'
            )

    for doc_id, title in sorted(sources.items(), key=lambda item: item[1]):
        lines.append(f"- **{title}** ({doc_id})")

    return '\n'.join(lines)


def format_json_report(session: "ResearchSession") -> str:
    """Format a research session as a pretty-printed JSON document.

    Args:
        session: The research session to serialize. Its question,
            sub-queries, synthesis, stats, cost summary, timeline, search
            steps and identified gaps are included.

    Returns:
        A JSON string indented with two spaces.
    """
    data = {
        'question': session.original_question,
        'sub_queries': session.sub_queries,
        'synthesis': session.synthesis,
        'stats': {
            'total_chunks': session.total_chunks,
            'unique_documents': session.unique_documents,
            'iterations': session.iterations,
            'gaps_identified': len(session.identified_gaps),
        },
        'cost': session.cost_summary,
        'budget_exceeded': session.budget_exceeded,
        'timeline': {
            'started': session.started_at,
            'completed': session.completed_at
        },
        'steps': [
            {
                'iteration': step.iteration,
                'query': step.query,
                'result_count': len(step.results),
                # Distinct titles among the first five results, in result order.
                # dict.fromkeys de-duplicates while keeping insertion order; a set
                # would give a different (hash-randomized) order on every run.
                'top_sources': list(dict.fromkeys(r.title for r in step.results[:5])),
            }
            for step in session.steps
        ],
        'identified_gaps': [gap.to_dict() for gap in session.identified_gaps],
        'auto_fill_gaps': session.auto_fill_gaps,
    }
    return json.dumps(data, indent=2)


# =============================================================================
# MAIN
# =============================================================================

def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the research agent."""
    parser = argparse.ArgumentParser(
        description='Autonomous Research Agent for iterative research tasks',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Single question
  %(prog)s "What is the nature of the etheric body?"

  # With output file
  %(prog)s "Compare Steiner and Jung on dreams" --output report.md

  # With higher budget for complex research
  %(prog)s "Complex multi-part question" --budget 5.00

  # Batch mode
  %(prog)s --batch questions.txt --output-dir reports/

  # Adjust parameters
  %(prog)s "Research question" --max-iterations 7 --no-graphrag

Cost Tracking:
  The agent tracks API costs and enforces budget limits to prevent runaway spending.
  Default budget is $1.00 per research session.
  Use --budget to increase for complex research tasks.
        """
    )

    parser.add_argument(
        'question',
        nargs='?',
        help='The research question to investigate'
    )
    parser.add_argument(
        '--batch',
        metavar='FILE',
        help='File with questions (one per line)'
    )
    parser.add_argument(
        '--output', '-o',
        metavar='FILE',
        help=('Output file path (for a multi-question batch without '
              '--output-dir, a _N suffix is added per question)')
    )
    parser.add_argument(
        '--output-dir',
        metavar='DIR',
        help='Output directory for batch mode'
    )
    parser.add_argument(
        '--format', '-f',
        choices=['markdown', 'json'],
        default='markdown',
        help='Output format (default: markdown)'
    )
    parser.add_argument(
        '--max-iterations',
        type=int,
        default=5,
        help='Maximum research iterations (default: 5)'
    )
    parser.add_argument(
        '--budget',
        type=float,
        default=1.0,
        metavar='USD',
        help='Maximum API budget in USD (default: $1.00)'
    )
    parser.add_argument(
        '--no-graphrag',
        action='store_true',
        help='Disable GraphRAG (use hybrid search only)'
    )
    parser.add_argument(
        '--no-rerank',
        action='store_true',
        help='Disable cross-encoder re-ranking'
    )
    parser.add_argument(
        '--strict-library-only',
        action='store_true',
        help='Disable web search - admit ignorance rather than searching external sources (maintains library purity)'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )
    return parser


def _load_batch_questions(batch_file: str) -> List[str]:
    """Read questions from *batch_file*, one per line.

    Blank lines and lines starting with '#' are skipped. Exits with status 1
    if the file does not exist.
    """
    batch_path = Path(batch_file)
    if not batch_path.exists():
        print(f"ERROR: Batch file not found: {batch_file}", file=sys.stderr)
        sys.exit(1)

    # utf-8-sig also reads plain UTF-8, and strips a leading BOM (as written by
    # some Windows editors) that would otherwise end up in the first question.
    with open(batch_path, 'r', encoding='utf-8-sig') as f:
        stripped = (line.strip() for line in f)
        return [q for q in stripped if q and not q.startswith('#')]


def _report_output_path(args, question: str, index: int, total: int,
                        ext: str, used_paths: set) -> Optional[Path]:
    """Choose where the report for one question is written; None means print it.

    With a single question, --output takes precedence over --output-dir.
    With several questions, --output-dir wins when given. A lone --output
    gets a _N suffix per question so the reports do not overwrite one another.
    File names derived from the question text are de-duplicated within this
    run using *used_paths*, which this function updates.
    """
    if args.output and (total == 1 or not args.output_dir):
        path = Path(args.output)
        if total > 1:
            path = path.with_name(f"{path.stem}_{index}{path.suffix}")
        return path

    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        stem = "".join(c if c.isalnum() else '_' for c in question[:40])
        path = output_dir / f"{stem}{ext}"
        # Questions that share their first 40 characters map to the same stem.
        counter = 2
        while path in used_paths:
            path = output_dir / f"{stem}_{counter}{ext}"
            counter += 1
        used_paths.add(path)
        return path

    return None


def _print_session_stats(session, default_budget: float) -> None:
    """Print chunk/document counts, the API cost summary and any budget warning."""
    print(f"\nStats: {session.total_chunks} chunks from {session.unique_documents} documents")
    print(f"Iterations: {session.iterations}")

    cost_summary = session.cost_summary or {}
    budget = cost_summary.get('budget_limit_usd', 0)
    if cost_summary:
        cost = cost_summary.get('estimated_cost_usd', 0)
        tokens = cost_summary.get('total_tokens', 0)
        requests = cost_summary.get('total_requests', 0)
        print(f"API Cost: ${cost:.4f} / ${budget:.2f} ({tokens:,} tokens, {requests} requests)")

    # Warn even when no cost summary was recorded.
    if session.budget_exceeded:
        print("\n⚠️  BUDGET EXCEEDED - Research stopped early")
        # Fall back to the CLI budget, then the $1.00 default, so the hint is never $0.00.
        suggested = (budget or default_budget or 1.0) * 2
        print(f"   Increase budget with: --budget {suggested:.2f}")


def _research_and_report(agent, args, question: str, index: int, total: int,
                         used_paths: set) -> None:
    """Research one question, write or print its report, then print run stats."""
    preview = question if len(question) <= 50 else question[:50] + "..."
    print(f"\n{'='*60}")
    print(f"Research {index}/{total}: {preview}")
    print('=' * 60)

    session = agent.research(question)

    if args.format == 'json':
        report = format_json_report(session)
        ext = '.json'
    else:
        report = format_markdown_report(session)
        ext = '.md'

    output_path = _report_output_path(args, question, index, total, ext, used_paths)
    if output_path:
        # --output may point into a directory that does not exist yet.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(report, encoding='utf-8')
        print(f"\nReport saved to: {output_path}")
    else:
        print("\n" + report)

    _print_session_stats(session, args.budget)


def main():
    """Command-line entry point: research one question or a batch of questions.

    Exit status is 2 for invalid arguments (via ``parser.error``). It is 1 when
    the batch file is missing or contains no questions, or when any question
    of a multi-question batch fails. Otherwise it is 0.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Validate everything before constructing the agent, so that bad input
    # fails fast instead of after agent initialization.
    if not args.question and not args.batch:
        parser.error("Either a question or --batch file is required")
    if args.max_iterations < 1:
        parser.error("--max-iterations must be at least 1")
    if args.budget < 0:
        parser.error("--budget must not be negative")

    if args.batch:
        questions = _load_batch_questions(args.batch)
        if not questions:
            print(f"ERROR: No questions found in batch file: {args.batch}", file=sys.stderr)
            sys.exit(1)
    else:
        questions = [args.question]

    agent = ResearchAgent(
        max_iterations=args.max_iterations,
        use_graphrag=not args.no_graphrag,
        use_rerank=not args.no_rerank,
        budget_limit_usd=args.budget,
        strict_library_only=args.strict_library_only
    )

    total = len(questions)
    used_paths = set()
    failures = 0
    for index, question in enumerate(questions, 1):
        try:
            _research_and_report(agent, args, question, index, total, used_paths)
        except Exception:
            # A single question propagates its error unchanged. In a batch, log
            # it and continue so one failure does not discard the remaining work.
            if total == 1:
                raise
            failures += 1
            logging.getLogger(__name__).exception(
                "Research failed for question %d/%d: %s", index, total, question
            )

    if failures:
        print(f"\n{failures}/{total} questions failed", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Execute the CLI only when run as a script, never on import.
    main()
