"""
Database utility functions for the Research Development Framework.
Provides connection management and common database operations.

v2.1 Additions:
- Cross-Encoder Re-ranking for improved search relevance
- Hybrid search with Reciprocal Rank Fusion (RRF)
- GraphRAG retrieval for multi-hop reasoning
"""

import psycopg2
from psycopg2.extras import RealDictCursor, execute_values
from contextlib import contextmanager
from typing import Optional, List, Dict, Any, Generator, Tuple
import logging

from config import DB_CONFIG, get_db_connection_string, OPENAI_ENABLED, OPENAI_API_KEY

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# =============================================================================
# CROSS-ENCODER RE-RANKING (Optimization #1)
# =============================================================================
# Lazy-load the cross-encoder model to avoid import overhead.
# _cross_encoder holds the model instance once one has been constructed.
_cross_encoder = None
# Tri-state flag: None = loading not attempted yet, True = a model was loaded,
# False = loading failed and _get_cross_encoder() returns None without retrying.
_cross_encoder_available = None

def _get_cross_encoder():
    """
    Return the shared cross-encoder model, loading it on first use.

    The first call tries 'BAAI/bge-reranker-base'; if constructing it raises,
    'cross-encoder/ms-marco-MiniLM-L-6-v2' is tried instead. The outcome is
    stored in module globals, so a failed load is not attempted again.

    Returns:
        The cached model instance, or None when sentence-transformers cannot
        be imported or no model could be loaded.
    """
    global _cross_encoder, _cross_encoder_available

    # An earlier attempt already failed: do not retry.
    if _cross_encoder_available is False:
        return None

    # Already loaded: hand back the cached instance.
    if _cross_encoder is not None:
        return _cross_encoder

    try:
        from sentence_transformers import CrossEncoder

        # Preferred model first; any failure falls through to the fallback.
        try:
            _cross_encoder = CrossEncoder('BAAI/bge-reranker-base', max_length=512)
            logger.info("Loaded BGE reranker model for re-ranking")
        except Exception:
            _cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
            logger.info("Loaded MS-MARCO MiniLM model for re-ranking")
    except ImportError:
        logger.warning("sentence-transformers not installed. Re-ranking unavailable.")
        logger.warning("Install with: pip install sentence-transformers")
        _cross_encoder_available = False
        return None
    except Exception as exc:
        logger.warning(f"Failed to load cross-encoder model: {exc}")
        _cross_encoder_available = False
        return None
    else:
        _cross_encoder_available = True

    return _cross_encoder


def rerank_results(
    query: str,
    results: List[Dict],
    top_k: int = 10,
    text_field: str = 'chunk_text'
) -> List[Dict]:
    """
    Re-order search results by cross-encoder relevance to the query.

    Each result is paired with the query and scored by the model returned
    from _get_cross_encoder(). The score is written onto the result dict
    itself (the input dicts are modified in place).

    Args:
        query: The original search query
        results: Result dictionaries from semantic/keyword search
        top_k: Number of results to keep after re-ranking
        text_field: Key holding the text to score ('' is scored if missing)

    Returns:
        The top_k results sorted by descending 'rerank_score'. If the model
        is unavailable or scoring raises, the first top_k results are
        returned in their original order without a 'rerank_score'. An empty
        input is returned unchanged.
    """
    # Nothing to score; skip loading the model altogether.
    if not results:
        return results

    model = _get_cross_encoder()
    if model is None:
        logger.debug("Cross-encoder unavailable, returning original order")
        return results[:top_k]

    # One (query, document text) pair per candidate, in input order.
    query_doc_pairs = [(query, item.get(text_field, '')) for item in results]

    try:
        relevance = model.predict(query_doc_pairs, show_progress_bar=False)
    except Exception as exc:
        logger.error(f"Re-ranking failed: {exc}")
        return results[:top_k]

    # Scores come back positionally aligned with the pairs.
    for position, item in enumerate(results):
        item['rerank_score'] = float(relevance[position])

    # Higher score means more relevant.
    ordered = sorted(
        results,
        key=lambda item: item.get('rerank_score', 0),
        reverse=True,
    )
    return ordered[:top_k]


def hybrid_search_with_rerank(
    query_text: str,
    query_embedding: List[float] = None,
    limit: int = 10,
    initial_fetch: int = 50,
    document_id: str = None,
    min_quality: str = 'fair',
    use_rerank: bool = True
) -> List[Dict]:
    """
    Hybrid keyword + semantic search fused with RRF, optionally re-ranked.

    Pipeline:
    1. Run keyword search, and semantic search when an embedding is available
    2. Merge both ranked lists using Reciprocal Rank Fusion (RRF)
    3. Optionally re-rank the fused candidates with the cross-encoder
    4. Return the final top results

    Args:
        query_text: Search query string
        query_embedding: Pre-computed query embedding. When None, one is
            generated via generate_query_embedding() if OPENAI_ENABLED is
            set; otherwise only keyword results are used.
        limit: Final number of results to return
        initial_fetch: Number of results fetched from each method before fusion
        document_id: Optional document filter
        min_quality: Minimum quality threshold
        use_rerank: Whether to apply cross-encoder re-ranking

    Returns:
        List of result dicts carrying 'rrf_score', 'keyword_rank' and/or
        'semantic_rank', plus 'rerank_score' when re-ranking ran.
    """
    from config import OPENAI_ENABLED

    # Filters shared by both retrieval calls.
    search_filters = {
        'limit': initial_fetch,
        'document_id': document_id,
        'min_quality': min_quality,
    }

    keyword_results = keyword_search(query_text=query_text, **search_filters)

    semantic_results = []
    if query_embedding is not None:
        semantic_results = semantic_search(
            query_embedding=query_embedding, **search_filters
        )
    elif OPENAI_ENABLED:
        # No embedding supplied: try to build one on the fly. Any failure in
        # this branch is logged and the search continues keyword-only.
        try:
            generated = generate_query_embedding(query_text)
            if generated:
                semantic_results = semantic_search(
                    query_embedding=generated, **search_filters
                )
        except Exception as exc:
            logger.warning(f"Could not generate embedding: {exc}")

    # Reciprocal Rank Fusion: each list contributes 1 / (k + rank) per chunk.
    rrf_k = 60  # Standard RRF constant
    score_by_chunk = {}  # chunk_id -> accumulated RRF score
    row_by_chunk = {}    # chunk_id -> first result dict seen for that chunk

    for position, row in enumerate(keyword_results, start=1):
        cid = row['chunk_id']
        contribution = 1.0 / (rrf_k + position)
        if cid not in row_by_chunk:
            row['keyword_rank'] = position
            row_by_chunk[cid] = row
            score_by_chunk[cid] = contribution
        else:
            score_by_chunk[cid] += contribution

    for position, row in enumerate(semantic_results, start=1):
        cid = row['chunk_id']
        contribution = 1.0 / (rrf_k + position)
        if cid in row_by_chunk:
            # Chunk already known: keep the existing dict, annotate it.
            score_by_chunk[cid] += contribution
            row_by_chunk[cid]['semantic_rank'] = position
        else:
            row['semantic_rank'] = position
            row_by_chunk[cid] = row
            score_by_chunk[cid] = contribution

    # Attach the fused score and order by it (stable, best first).
    fused_results = []
    for cid, row in row_by_chunk.items():
        row['rrf_score'] = score_by_chunk[cid]
        fused_results.append(row)
    fused_results.sort(key=lambda row: row['rrf_score'], reverse=True)

    if not (use_rerank and fused_results):
        return fused_results[:limit]

    # Re-rank a wider candidate pool than the final limit.
    candidates = fused_results[:min(initial_fetch, len(fused_results))]
    return rerank_results(query_text, candidates, top_k=limit)


def generate_query_embedding(query_text: str) -> Optional[List[float]]:
    """
    Generate an embedding for a query string using the OpenAI API.

    Args:
        query_text: Text to embed

    Returns:
        The embedding of the first item in the API response, or None when
        OPENAI_ENABLED is falsy or anything in the request raises (the error
        is logged, not propagated).
    """
    vector = None

    if OPENAI_ENABLED:
        try:
            # Imported lazily so a missing package is handled like any
            # other failure below.
            from openai import OpenAI
            from config import EMBEDDING_MODEL

            api_response = OpenAI(api_key=OPENAI_API_KEY).embeddings.create(
                model=EMBEDDING_MODEL,
                input=query_text
            )
            vector = api_response.data[0].embedding
        except Exception as exc:
            logger.error(f"Failed to generate query embedding: {exc}")

    return vector


@contextmanager
def get_db_connection() -> Generator:
    """
    Yield a psycopg2 connection opened with DB_CONFIG.

    When the managed block finishes normally the transaction is committed.
    If an Exception is raised (including while connecting), the transaction
    is rolled back when a connection exists, the error is logged, and the
    exception is re-raised. The connection is closed in every case.

    Usage:
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT * FROM documents")
    """
    connection = None
    try:
        connection = psycopg2.connect(**DB_CONFIG)
        yield connection
        # Reached only when the caller's block raised nothing.
        connection.commit()
    except Exception as exc:
        if connection:
            connection.rollback()
        logger.error(f"Database error: {exc}")
        raise
    finally:
        if connection:
            connection.close()


@contextmanager
def get_dict_cursor() -> Generator:
    """
    Yield a (connection, cursor) pair whose cursor returns rows as dicts.

    The connection comes from get_db_connection(), so its commit/rollback
    and close handling apply; the cursor uses RealDictCursor.

    Usage:
        with get_dict_cursor() as (conn, cur):
            cur.execute("SELECT * FROM documents WHERE document_id = %s", (doc_id,))
            row = cur.fetchone()  # Returns dict
    """
    with get_db_connection() as connection, \
            connection.cursor(cursor_factory=RealDictCursor) as dict_cursor:
        yield connection, dict_cursor


def execute_query(query: str, params: tuple = None, fetch: str = 'all') -> Any:
    """
    Run a single statement on a fresh connection and optionally fetch rows.

    Args:
        query: SQL query string
        params: Query parameters passed to cursor.execute
        fetch: 'all' to fetch every row, 'one' to fetch a single row;
               any other value fetches nothing

    Returns:
        cursor.fetchall() for 'all', cursor.fetchone() for 'one',
        otherwise None
    """
    with get_dict_cursor() as (_connection, cursor):
        cursor.execute(query, params)

        if fetch == 'one':
            return cursor.fetchone()
        if fetch == 'all':
            return cursor.fetchall()
        # Any other mode: statement executed, nothing fetched.
        return None


def insert_document(document_data: Dict[str, Any]) -> str:
    """
    Insert a document row, or touch updated_at if the id already exists.

    The dictionary keys are used verbatim as column names and the values
    are bound as parameters. On a document_id conflict only updated_at is
    refreshed.

    Args:
        document_data: Mapping of column name to value

    Returns:
        The document_id returned by the statement
    """
    # Keys become the column list; one '%s' placeholder per key.
    column_names = ', '.join(document_data)
    placeholders = ', '.join('%s' for _ in document_data)
    values = list(document_data.values())

    query = f"""
        INSERT INTO documents ({column_names})
        VALUES ({placeholders})
        ON CONFLICT (document_id) DO UPDATE SET
            updated_at = CURRENT_TIMESTAMP
        RETURNING document_id
    """

    with get_dict_cursor() as (_connection, cursor):
        cursor.execute(query, values)
        return cursor.fetchone()['document_id']


def insert_chunk(chunk_data: Dict[str, Any]) -> str:
    """
    Insert a chunk row, or update its text and token count on conflict.

    The dictionary keys are used verbatim as column names and the values
    are bound as parameters. On a chunk_id conflict, chunk_text and
    chunk_tokens are overwritten with the new values.

    Args:
        chunk_data: Mapping of column name to value

    Returns:
        The chunk_id returned by the statement
    """
    # Keys become the column list; one '%s' placeholder per key.
    column_names = ', '.join(chunk_data)
    placeholders = ', '.join('%s' for _ in chunk_data)
    values = list(chunk_data.values())

    query = f"""
        INSERT INTO chunks ({column_names})
        VALUES ({placeholders})
        ON CONFLICT (chunk_id) DO UPDATE SET
            chunk_text = EXCLUDED.chunk_text,
            chunk_tokens = EXCLUDED.chunk_tokens
        RETURNING chunk_id
    """

    with get_dict_cursor() as (_connection, cursor):
        cursor.execute(query, values)
        return cursor.fetchone()['chunk_id']


def batch_insert_chunks(chunks: List[Dict[str, Any]]) -> int:
    """
    Insert many chunks with a single execute_values call.

    The column list is taken from the keys of the first chunk; every other
    chunk must provide the same keys. On a chunk_id conflict the text,
    token count and page range are overwritten.

    Args:
        chunks: List of chunk dictionaries

    Returns:
        len(chunks), or 0 when the list is empty (no database access)
    """
    if not chunks:
        return 0

    # The first chunk defines the column order for the whole batch.
    columns = list(chunks[0].keys())
    column_names = ', '.join(columns)

    query = f"""
        INSERT INTO chunks ({column_names})
        VALUES %s
        ON CONFLICT (chunk_id) DO UPDATE SET
            chunk_text = EXCLUDED.chunk_text,
            chunk_tokens = EXCLUDED.chunk_tokens,
            page_start = EXCLUDED.page_start,
            page_end = EXCLUDED.page_end
    """

    rows = [tuple(record[name] for name in columns) for record in chunks]

    with get_db_connection() as connection, connection.cursor() as cursor:
        execute_values(cursor, query, rows)
        return len(chunks)


def update_embedding(chunk_id: str, embedding: List[float]) -> None:
    """
    Store an embedding vector on a single chunk.

    Args:
        chunk_id: The chunk identifier
        embedding: List of floats; cast to the vector type in SQL
    """
    query = """
        UPDATE chunks
        SET embedding = %s::vector
        WHERE chunk_id = %s
    """
    # Parameter order follows the placeholders: vector first, then the id.
    bound_values = (embedding, chunk_id)

    with get_db_connection() as connection, connection.cursor() as cursor:
        cursor.execute(query, bound_values)


def batch_update_embeddings(embeddings: List[tuple]) -> int:
    """
    Update the embeddings of many chunks via execute_values.

    The tuples are expanded into a VALUES list that is joined to chunks
    on chunk_id.

    Args:
        embeddings: List of (chunk_id, embedding_vector) tuples

    Returns:
        len(embeddings), or 0 when the list is empty (no database access)
    """
    if not embeddings:
        return 0

    query = """
        UPDATE chunks
        SET embedding = data.embedding::vector
        FROM (VALUES %s) AS data(chunk_id, embedding)
        WHERE chunks.chunk_id = data.chunk_id
    """

    with get_db_connection() as connection, connection.cursor() as cursor:
        execute_values(cursor, query, embeddings)
        return len(embeddings)


def update_document_status(document_id: str, status: str, error_message: str = None) -> None:
    """
    Set a document's processing_status and refresh updated_at.

    Args:
        document_id: The document identifier
        status: New status (pending, processing, completed, failed)
        error_message: When not None, replaces the document's notes;
                       when None, the existing notes are kept
    """
    query = """
        UPDATE documents
        SET processing_status = %s,
            notes = CASE WHEN %s IS NOT NULL THEN %s ELSE notes END,
            updated_at = CURRENT_TIMESTAMP
        WHERE document_id = %s
    """
    # error_message is bound twice: once for the NULL test, once as the value.
    bound_values = (status, error_message, error_message, document_id)

    with get_db_connection() as connection, connection.cursor() as cursor:
        cursor.execute(query, bound_values)


def get_document(document_id: str) -> Optional[Dict]:
    """
    Fetch a single row from documents by its document_id.

    Args:
        document_id: The document identifier

    Returns:
        The row as a dictionary, or None when no row matches
    """
    sql = "SELECT * FROM documents WHERE document_id = %s"
    row = execute_query(sql, (document_id,), fetch='one')
    return row


def get_chunks_for_document(document_id: str) -> List[Dict]:
    """
    Fetch every chunk of a document, ordered by chunk_sequence.

    Args:
        document_id: The document identifier

    Returns:
        List of chunk dictionaries
    """
    sql = "SELECT * FROM chunks WHERE document_id = %s ORDER BY chunk_sequence"
    rows = execute_query(sql, (document_id,), fetch='all')
    return rows


def get_chunks_without_embeddings(limit: int = 100) -> List[Dict]:
    """
    Fetch chunks whose embedding column is still NULL.

    Args:
        limit: Maximum number of chunks to return

    Returns:
        List of dictionaries with 'chunk_id' and 'chunk_text'
    """
    sql = """
        SELECT chunk_id, chunk_text
        FROM chunks
        WHERE embedding IS NULL
        LIMIT %s
        """
    pending = execute_query(sql, (limit,), fetch='all')
    return pending


def semantic_search(
    query_embedding: List[float],
    limit: int = 10,
    document_id: str = None,
    min_quality: str = 'fair'
) -> List[Dict]:
    """
    Find the chunks closest to a query vector using pgvector's <=> operator.

    Args:
        query_embedding: Query vector
        limit: Maximum results
        document_id: Optional filter by document. When given, the rows do
                    not include the document title/quality columns.
        min_quality: Lowest quality status to include, out of
                    ('excellent', 'good', 'fair', 'poor', 'unusable').
                    None, or any value not in that list, disables the filter.
                    Default 'fair' excludes 'poor' and 'unusable' documents.

    Returns:
        List of matching chunks with a 'similarity' column (1 - distance),
        nearest first
    """
    # pgvector text form: "[v1,v2,...]"
    embedding_str = f"[{','.join(map(str, query_embedding))}]"

    # Quality levels are ordered best to worst; keep everything up to and
    # including the requested minimum.
    quality_levels = ['excellent', 'good', 'fair', 'poor', 'unusable']
    quality_filter = ""
    if min_quality and min_quality in quality_levels:
        cutoff = quality_levels.index(min_quality) + 1
        quoted_levels = ','.join(f"'{q}'" for q in quality_levels[:cutoff])
        quality_filter = "AND d.quality_status IN ({})".format(quoted_levels)

    if not document_id:
        query = f"""
            SELECT
                c.chunk_id,
                c.document_id,
                c.chunk_text,
                c.chunk_sequence,
                d.title,
                d.quality_status,
                1 - (c.embedding <=> %s::vector) AS similarity
            FROM chunks c
            JOIN documents d ON c.document_id = d.document_id
            WHERE c.embedding IS NOT NULL
              {quality_filter}
            ORDER BY c.embedding <=> %s::vector
            LIMIT %s
        """
        params = (embedding_str, embedding_str, limit)
    else:
        query = f"""
            SELECT
                c.chunk_id,
                c.document_id,
                c.chunk_text,
                c.chunk_sequence,
                1 - (c.embedding <=> %s::vector) AS similarity
            FROM chunks c
            JOIN documents d ON c.document_id = d.document_id
            WHERE c.document_id = %s
              AND c.embedding IS NOT NULL
              {quality_filter}
            ORDER BY c.embedding <=> %s::vector
            LIMIT %s
        """
        params = (embedding_str, document_id, embedding_str, limit)

    # Both statements share one connection so the probes setting applies
    # to the search query that follows it.
    with get_dict_cursor() as (_connection, cursor):
        cursor.execute("SET ivfflat.probes = 10")
        cursor.execute(query, params)
        return cursor.fetchall()


def keyword_search(
    query_text: str,
    limit: int = 10,
    document_id: str = None,
    min_quality: str = 'fair'
) -> List[Dict]:
    """
    Full-text search over chunk_text_tsv using plainto_tsquery('english').

    Args:
        query_text: Search query
        limit: Maximum results
        document_id: Optional filter by document. When given, the rows do
                    not include the document title/quality columns.
        min_quality: Lowest quality status to include, out of
                    ('excellent', 'good', 'fair', 'poor', 'unusable').
                    None, or any value not in that list, disables the filter.
                    Default 'fair' excludes 'poor' and 'unusable' documents.

    Returns:
        List of matching chunks with a ts_rank 'relevance' column,
        most relevant first
    """
    # Quality levels are ordered best to worst; keep everything up to and
    # including the requested minimum.
    quality_levels = ['excellent', 'good', 'fair', 'poor', 'unusable']
    quality_filter = ""
    if min_quality and min_quality in quality_levels:
        cutoff = quality_levels.index(min_quality) + 1
        quoted_levels = ','.join(f"'{q}'" for q in quality_levels[:cutoff])
        quality_filter = "AND d.quality_status IN ({})".format(quoted_levels)

    if not document_id:
        query = f"""
            SELECT
                c.chunk_id,
                c.document_id,
                c.chunk_text,
                c.chunk_sequence,
                d.title,
                d.quality_status,
                ts_rank(c.chunk_text_tsv, plainto_tsquery('english', %s)) AS relevance
            FROM chunks c
            JOIN documents d ON c.document_id = d.document_id
            WHERE c.chunk_text_tsv @@ plainto_tsquery('english', %s)
              {quality_filter}
            ORDER BY relevance DESC
            LIMIT %s
        """
        params = (query_text, query_text, limit)
    else:
        query = f"""
            SELECT
                c.chunk_id,
                c.document_id,
                c.chunk_text,
                c.chunk_sequence,
                ts_rank(c.chunk_text_tsv, plainto_tsquery('english', %s)) AS relevance
            FROM chunks c
            JOIN documents d ON c.document_id = d.document_id
            WHERE c.document_id = %s
              AND c.chunk_text_tsv @@ plainto_tsquery('english', %s)
              {quality_filter}
            ORDER BY relevance DESC
            LIMIT %s
        """
        params = (query_text, document_id, query_text, limit)

    matches = execute_query(query, params, fetch='all')
    return matches


def get_database_stats() -> Dict:
    """
    Read the first row of the v_document_stats view.

    Returns:
        That row as a plain dict, or an empty dict when no row comes back
    """
    row = execute_query(
        "SELECT * FROM v_document_stats",
        fetch='one'
    )
    if not row:
        return {}
    return dict(row)


def log_search(query_text: str, search_type: str, results_count: int, response_time_ms: int) -> None:
    """
    Record one search in the search_history table for analytics.

    Args:
        query_text: The query that was run
        search_type: Label for the kind of search performed
        results_count: Number of results returned
        response_time_ms: Response time in milliseconds
    """
    query = """
        INSERT INTO search_history (query_text, search_type, results_count, response_time_ms)
        VALUES (%s, %s, %s, %s)
    """
    record = (query_text, search_type, results_count, response_time_ms)

    with get_db_connection() as connection, connection.cursor() as cursor:
        cursor.execute(query, record)


# =============================================================================
# GRAPHRAG RETRIEVAL (Optimization #2)
# =============================================================================

def find_concept_by_name(concept_name: str, fuzzy: bool = True) -> Optional[Dict]:
    """
    Look up a concept by name: exact case-insensitive match first, then
    (optionally) the closest trigram-similar name.

    Args:
        concept_name: Name to search for
        fuzzy: If True and no exact match exists, fall back to the row with
               the highest similarity() above 0.3

    Returns:
        Concept dictionary (fuzzy matches also carry a 'sim' column),
        or None when nothing matches
    """
    exact_match = execute_query(
        "SELECT * FROM concepts WHERE LOWER(name) = LOWER(%s)",
        (concept_name,),
        fetch='one'
    )
    if exact_match:
        return exact_match

    if not fuzzy:
        return None

    # No exact hit: take the single best trigram match (pg_trgm similarity).
    fuzzy_sql = """
            SELECT *, similarity(LOWER(name), LOWER(%s)) AS sim
            FROM concepts
            WHERE similarity(LOWER(name), LOWER(%s)) > 0.3
            ORDER BY sim DESC
            LIMIT 1
            """
    return execute_query(fuzzy_sql, (concept_name, concept_name), fetch='one')


def get_related_concepts(
    concept_id: int,
    depth: int = 1,
    min_cooccurrence: int = 2
) -> List[Dict]:
    """
    Get concepts related to a given concept via co-occurrence.

    Performs a breadth-first traversal: two concepts are neighbours when
    they are linked to the same chunk in chunk_concepts at least
    min_cooccurrence times. One query is issued per expanded concept.

    Each related concept is reported exactly once, at the smallest hop
    distance at which it is reached, attributed to the first source concept
    that reaches it. The starting concept itself is never reported.

    Args:
        concept_id: The source concept ID
        depth: How many hops to traverse (1 = direct neighbors, 2 = neighbors of neighbors)
        min_cooccurrence: Minimum co-occurrence count for a relationship

    Returns:
        List of dicts with 'concept_id', 'name', 'category',
        'cooccurrence_count', 'hop_distance' and 'source_concept_id'.
        Empty when depth is 0 or nothing co-occurs often enough.
    """
    # Every concept already reported (plus the start node). Tracking
    # discovery rather than expansion prevents a concept from being listed
    # twice by two sources of the same level, or again at a deeper hop.
    discovered = {concept_id}
    results = []
    # A list (not a set) keeps expansion order deterministic: strongest
    # co-occurrences of earlier sources are expanded first.
    frontier = [concept_id]

    for level in range(depth):
        next_frontier = []

        for cid in frontier:
            # Find co-occurring concepts
            related = execute_query(
                """
                SELECT
                    CASE WHEN cc1.concept_id = %s THEN cc2.concept_id ELSE cc1.concept_id END AS related_id,
                    c.name,
                    c.category,
                    COUNT(DISTINCT cc1.chunk_id) AS cooccurrence_count
                FROM chunk_concepts cc1
                JOIN chunk_concepts cc2 ON cc1.chunk_id = cc2.chunk_id
                JOIN concepts c ON c.concept_id = CASE
                    WHEN cc1.concept_id = %s THEN cc2.concept_id
                    ELSE cc1.concept_id
                END
                WHERE (cc1.concept_id = %s OR cc2.concept_id = %s)
                  AND cc1.concept_id != cc2.concept_id
                GROUP BY related_id, c.name, c.category
                HAVING COUNT(DISTINCT cc1.chunk_id) >= %s
                ORDER BY cooccurrence_count DESC
                """,
                (cid, cid, cid, cid, min_cooccurrence),
                fetch='all'
            )

            for r in related:
                related_id = r['related_id']
                if related_id in discovered:
                    continue
                discovered.add(related_id)
                next_frontier.append(related_id)
                results.append({
                    'concept_id': related_id,
                    'name': r['name'],
                    'category': r['category'],
                    'cooccurrence_count': r['cooccurrence_count'],
                    'hop_distance': level + 1,
                    'source_concept_id': cid
                })

        # Nothing new was reached: deeper levels would be empty too.
        if not next_frontier:
            break
        frontier = next_frontier

    return results


def get_chunks_for_concepts(concept_ids: List[int], limit: int = 20) -> List[Dict]:
    """
    Get chunks that contain any of the specified concepts.

    Only chunks from documents whose quality_status is 'excellent', 'good'
    or 'fair' are considered. Chunks are ranked by the total mention count
    of the requested concepts, so the limit keeps the most relevant chunks.

    Args:
        concept_ids: List of concept IDs to search for
        limit: Maximum number of chunks to return

    Returns:
        List of chunk dictionaries (one per chunk) with 'concept_names' and
        'total_mentions' added, most-mentioned first. Empty when concept_ids
        is empty (no database access).
    """
    if not concept_ids:
        return []

    # One bound placeholder per concept id.
    placeholders = ','.join(['%s'] * len(concept_ids))

    # GROUP BY already yields one row per chunk, so no DISTINCT ON is needed.
    # DISTINCT ON would force chunk_id to lead the ORDER BY, which made LIMIT
    # keep the lowest chunk ids instead of the most-mentioned chunks.
    # NULLS LAST keeps rows without a mention count from outranking real
    # counts; chunk_id is the tie-breaker for a stable order.
    results = execute_query(
        f"""
        SELECT
            c.chunk_id,
            c.document_id,
            c.chunk_text,
            c.chunk_sequence,
            d.title,
            d.quality_status,
            array_agg(DISTINCT con.name) AS concept_names,
            SUM(cc.mention_count) AS total_mentions
        FROM chunks c
        JOIN chunk_concepts cc ON c.chunk_id = cc.chunk_id
        JOIN concepts con ON cc.concept_id = con.concept_id
        JOIN documents d ON c.document_id = d.document_id
        WHERE cc.concept_id IN ({placeholders})
          AND d.quality_status IN ('excellent', 'good', 'fair')
        GROUP BY c.chunk_id, c.document_id, c.chunk_text, c.chunk_sequence, d.title, d.quality_status
        ORDER BY total_mentions DESC NULLS LAST, c.chunk_id
        LIMIT %s
        """,
        (*concept_ids, limit),
        fetch='all'
    )

    return results


def graphrag_search(
    query: str,
    limit: int = 10,
    hop_depth: int = 2,
    min_cooccurrence: int = 2,
    include_direct_search: bool = True
) -> Dict[str, Any]:
    """
    GraphRAG: Retrieval-Augmented Generation using knowledge graph traversal.

    This function enables multi-hop reasoning by:
    1. Identifying concepts in the query (single words longer than two
       characters, then 2-3 word phrases, each looked up with fuzzy matching)
    2. Traversing the co-occurrence graph to find related concepts
    3. Retrieving chunks for both the matched and the related concepts
    4. Optionally appending standard hybrid search results

    Graph-derived chunks come first in the result list; hybrid results are
    appended after them, so they only appear when the graph step returned
    fewer than `limit` chunks.

    Args:
        query: The user's question
        limit: Maximum results to return
        hop_depth: How many relationship hops to traverse
        min_cooccurrence: Minimum co-occurrence for relationships
        include_direct_search: Also include standard hybrid search results

    Returns:
        Dictionary with:
        - 'results': Combined chunk results, each tagged with
          'retrieval_method' ('graph' or 'hybrid')
        - 'concepts': Distinct concepts matched in the query
        - 'graph_paths': Relationship paths explored
        - 'method': Always 'graphrag'
    """
    response = {
        'results': [],
        'concepts': [],
        'graph_paths': [],
        'method': 'graphrag'
    }

    # Step 1: Extract potential concepts from query
    # Use simple word-based extraction; LLM extraction could enhance this
    words = query.split()

    # Single words first (lower-cased, longer than two characters) ...
    candidates = [w.lower() for w in words if len(w) > 2]
    # ... then multi-word phrases (2-3 words).
    for start in range(len(words)):
        for end in range(start + 2, min(start + 4, len(words) + 1)):
            candidates.append(' '.join(words[start:end]))

    # Distinct concepts by concept_id: the same concept can be matched by
    # several words/phrases, and fuzzy rows differ only in their 'sim' value.
    found_concepts = []
    seen_concept_ids = set()
    # dict.fromkeys drops repeated candidates (keeping first-seen order) so
    # the same string is not looked up twice.
    for candidate in dict.fromkeys(candidates):
        concept = find_concept_by_name(candidate, fuzzy=True)
        if not concept:
            continue
        concept_id = concept['concept_id']
        if concept_id in seen_concept_ids:
            continue
        seen_concept_ids.add(concept_id)
        found_concepts.append(concept)

    response['concepts'] = [
        {'id': c['concept_id'], 'name': c['name'], 'category': c.get('category')}
        for c in found_concepts
    ]

    # Step 2: Traverse graph to find related concepts
    all_concept_ids = set()
    for concept in found_concepts:
        all_concept_ids.add(concept['concept_id'])
        related = get_related_concepts(
            concept['concept_id'],
            depth=hop_depth,
            min_cooccurrence=min_cooccurrence
        )
        for r in related:
            all_concept_ids.add(r['concept_id'])
            response['graph_paths'].append({
                'from': concept['name'],
                'to': r['name'],
                'hops': r['hop_distance'],
                'strength': r['cooccurrence_count']
            })

    # Step 3: Get chunks for all concepts
    if all_concept_ids:
        graph_chunks = get_chunks_for_concepts(list(all_concept_ids), limit=limit * 2)

        # Add source info
        for chunk in graph_chunks:
            chunk['retrieval_method'] = 'graph'
            response['results'].append(chunk)

    # Step 4: Optionally combine with direct search
    if include_direct_search:
        # Best-effort: a failing hybrid search must not discard graph results.
        try:
            direct_results = hybrid_search_with_rerank(
                query_text=query,
                limit=limit,
                use_rerank=True
            )
            # Track chunk ids already present to avoid duplicates.
            seen_chunk_ids = {r['chunk_id'] for r in response['results']}
            for result in direct_results:
                if result['chunk_id'] in seen_chunk_ids:
                    continue
                seen_chunk_ids.add(result['chunk_id'])
                result['retrieval_method'] = 'hybrid'
                response['results'].append(result)
        except Exception as e:
            logger.warning(f"Direct search failed in GraphRAG: {e}")

    # Limit final results
    response['results'] = response['results'][:limit]

    return response


def get_parent_chunk(chunk_id: str) -> Optional[Dict]:
    """
    Fetch the chunk referenced by a chunk's parent_chunk_id
    (hierarchical chunking).

    Args:
        chunk_id: The child chunk ID

    Returns:
        Parent chunk dictionary, or None when the self-join finds no row
    """
    parent_sql = """
        SELECT p.*
        FROM chunks c
        JOIN chunks p ON c.parent_chunk_id = p.chunk_id
        WHERE c.chunk_id = %s
        """
    return execute_query(parent_sql, (chunk_id,), fetch='one')


def get_chunk_with_context(
    chunk_id: str,
    context_chunks: int = 1,
    use_parent: bool = True
) -> Dict[str, Any]:
    """
    Fetch a chunk together with context for RAG generation.

    Args:
        chunk_id: The target chunk ID
        context_chunks: Number of preceding/following chunks to include
        use_parent: When True and the chunk has a parent, return the parent
                    instead of the neighbouring chunks

    Returns:
        {} when the chunk does not exist;
        {'chunk', 'parent'} when use_parent is set and a parent was found;
        otherwise {'chunk', 'context_before', 'context_after'} with
        neighbours from the same document ordered by chunk_sequence
    """
    target = execute_query(
        "SELECT * FROM chunks WHERE chunk_id = %s",
        (chunk_id,),
        fetch='one'
    )
    if not target:
        return {}

    # A parent chunk, when present, replaces the neighbour window entirely.
    if use_parent:
        parent = get_parent_chunk(chunk_id)
        if parent:
            return {'chunk': target, 'parent': parent}

    # Both window queries bind the same values:
    # (document, sequence, sequence, window size).
    window_params = (
        target['document_id'],
        target['chunk_sequence'],
        target['chunk_sequence'],
        context_chunks,
    )

    before_sql = """
        SELECT * FROM chunks
        WHERE document_id = %s
          AND chunk_sequence < %s
          AND chunk_sequence >= %s - %s
        ORDER BY chunk_sequence
        """
    after_sql = """
        SELECT * FROM chunks
        WHERE document_id = %s
          AND chunk_sequence > %s
          AND chunk_sequence <= %s + %s
        ORDER BY chunk_sequence
        """

    preceding = execute_query(before_sql, window_params, fetch='all')
    following = execute_query(after_sql, window_params, fetch='all')

    return {
        'chunk': target,
        'context_before': preceding,
        'context_after': following,
    }
