#!/usr/bin/env python3
"""
Document Chunking Pipeline for Research Development Framework v2.1

This script processes ingested documents and creates searchable chunks:
1. Load documents that need chunking
2. Split text into optimal chunks with overlap
3. OPTIONAL: Create hierarchical parent-child chunks (v2.0)
4. OPTIONAL: Use semantic boundary detection (NEW in v2.1)
5. Store chunks in database with full-text search vectors
6. Queue chunks for embedding generation

Chunking Modes:
- Standard: Fixed-size chunks with overlap (default)
- Hierarchical: Parent chunks (2000 tokens) + child chunks (200 tokens)
  - Parent chunks provide context for LLM generation
  - Child chunks enable precise search/quote finding
- Semantic: Uses embedding similarity to detect topic boundaries (NEW in v2.1)
  - Chunks break at semantic shift points rather than arbitrary token counts
  - Preserves topical coherence within each chunk

Usage:
    python chunk_documents.py                         # Standard chunking
    python chunk_documents.py --hierarchical          # Hierarchical chunking
    python chunk_documents.py --semantic              # Semantic boundary chunking
    python chunk_documents.py --document DOC_001      # Chunk specific document
    python chunk_documents.py --rechunk               # Force rechunk all documents
"""

import re
import logging
from pathlib import Path
from typing import List, Dict, Any, Generator
import argparse

try:
    import tiktoken
    HAS_TIKTOKEN = True
except ImportError:
    HAS_TIKTOKEN = False
    print("Warning: tiktoken not installed. Using approximate token counting.")

from config import CHUNK_CONFIG, PROCESSING_CONFIG, LOGGING_CONFIG
from db_utils import (
    get_db_connection,
    execute_query,
    batch_insert_chunks,
    update_document_status
)

# Setup logging
# Configures the root logger once, at import time. LOGGING_CONFIG['level'] is
# looked up as an attribute name on the logging module, and
# LOGGING_CONFIG['format'] is passed through as the record format string.
logging.basicConfig(
    level=getattr(logging, LOGGING_CONFIG['level']),
    format=LOGGING_CONFIG['format']
)
# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)


class TextChunker:
    """Handles intelligent text chunking with overlap."""

    def __init__(self):
        """Read the chunk-size limits from CHUNK_CONFIG and set up the tokenizer."""
        # Every limit is copied onto the instance under the same name it has
        # in CHUNK_CONFIG.
        for limit_name in ('target_tokens', 'overlap_tokens', 'min_tokens', 'max_tokens'):
            setattr(self, limit_name, CHUNK_CONFIG[limit_name])

        # Without tiktoken the encoder stays None and count_tokens() falls
        # back to a character-based estimate.
        self.encoder = (
            tiktoken.get_encoding(CHUNK_CONFIG['encoding']) if HAS_TIKTOKEN else None
        )

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in *text*.

        Uses the configured encoder when there is one; otherwise estimates
        the count at roughly four characters per token.
        """
        if not self.encoder:
            return len(text) // 4
        return len(self.encoder.encode(text))

    def split_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences while keeping common abbreviations intact.

        Dots that do not end a sentence (between capital initials, or after
        titles/abbreviations such as "Dr." or "e.g.") are temporarily replaced
        by a placeholder, the text is split on whitespace that follows
        '.', '!' or '?', and the placeholder is then turned back into a dot.

        Args:
            text: Text to split.

        Returns:
            Non-empty sentences with surrounding whitespace removed.
        """
        # Dots between capital letters (initials such as "U.S") do not end a sentence.
        text = re.sub(r'(?<=[A-Z])\.(?=[A-Z])', '@@DOT@@', text)
        # The leading \b is essential: without it any word that merely ends in
        # one of these letter groups ("systems.", "items.") is treated as an
        # abbreviation and the real sentence break after it is lost.
        text = re.sub(r'\b(Mr|Mrs|Ms|Dr|Prof|Jr|Sr|vs|etc|i\.e|e\.g)\.',
                      r'\1@@DOT@@', text, flags=re.IGNORECASE)

        # Split on whitespace that follows sentence-ending punctuation.
        pieces = re.split(r'(?<=[.!?])\s+', text)

        # Restore the protected dots and drop empty fragments.
        restored = (piece.replace('@@DOT@@', '.').strip() for piece in pieces)
        return [sentence for sentence in restored if sentence]

    def split_into_paragraphs(self, text: str) -> List[str]:
        """Return the non-empty, stripped blocks of *text* separated by blank lines."""
        paragraphs = []
        # A paragraph break is a newline, optional whitespace, then another newline.
        for block in re.split(r'\n\s*\n', text):
            block = block.strip()
            if block:
                paragraphs.append(block)
        return paragraphs

    def get_pages_for_chunk(self, chunk_start: int, chunk_end: int, page_map: list) -> Dict[str, int]:
        """
        Get page numbers for a chunk based on character positions.

        A position is attributed to the page whose [start_char, end_char]
        range contains it (bounds are treated as inclusive). When several
        ranges contain the start position the first one wins; for the end
        position the last one wins. A position that no range contains (for
        example one that falls in a gap between two pages, or past the last
        page) is attributed to the nearest page that begins at or before it,
        or to the first page when every page begins after it.

        Args:
            chunk_start: Start character position in full text
            chunk_end: End character position in full text
            page_map: List of page mappings from PDF extraction, each with
                      'page', 'start_char' and 'end_char' keys

        Returns:
            Dict with 'page_start' and 'page_end' keys (both None when
            page_map is empty or None)
        """
        if not page_map:
            return {'page_start': None, 'page_end': None}

        page_start = None
        page_end = None

        for pm in page_map:
            # First page containing the start position.
            if page_start is None and pm['start_char'] <= chunk_start <= pm['end_char']:
                page_start = pm['page']
            # Last page containing the end position.
            if pm['start_char'] <= chunk_end <= pm['end_char']:
                page_end = pm['page']

        def nearest_preceding_page(position):
            """Page with the greatest start_char <= position, else the first page."""
            candidates = [pm for pm in page_map if pm['start_char'] <= position]
            if not candidates:
                return page_map[0]['page']
            return max(candidates, key=lambda pm: pm['start_char'])['page']

        # Falling back to "first page" for the start and "last page" for the
        # end would report a span covering the whole document whenever a
        # position lands between two page ranges, so use the closest page.
        if page_start is None:
            page_start = nearest_preceding_page(chunk_start)
        if page_end is None:
            page_end = nearest_preceding_page(chunk_end)

        return {'page_start': page_start, 'page_end': page_end}

    def chunk_text(self, text: str, document_id: str, page_map: list = None) -> Generator[Dict[str, Any], None, None]:
        """
        Split text into chunks with overlap and page tracking.

        Leading markdown front matter (text starting with '---' up to the
        next '---' line) is removed before chunking. The remaining text is
        split into paragraphs; a paragraph larger than ``max_tokens`` is
        further split into sentences. Pieces are accumulated until adding the
        next one would exceed ``target_tokens``; the accumulated chunk is then
        emitted and its trailing pieces (two for sentences, one for
        paragraphs) are carried into the next chunk as overlap.

        A buffer that is still below ``min_tokens`` is never emitted or
        discarded mid-document; it keeps growing instead, so no text is lost
        between chunks. A final chunk below ``min_tokens`` is not emitted.

        Args:
            text: Full document text
            document_id: Document identifier
            page_map: Optional list of page mappings from PDF extraction
                      [{'page': 1, 'start_char': 0, 'end_char': 1000}, ...]

        Yields:
            Chunk dictionaries ready for database insertion, with keys
            chunk_id, document_id, chunk_sequence, chunk_text, chunk_tokens,
            page_start and page_end.
        """
        front_matter_offset = 0

        # Remove markdown front matter if present, remembering how many
        # characters were cut so positions can be mapped back to the full text.
        if text.startswith('---'):
            end_match = re.search(r'\n---\n', text[3:])
            if end_match:
                front_matter_offset = end_match.end() + 3
                text = text[front_matter_offset:]

        chunk_sequence = 0

        def build_chunk(items):
            """Turn a non-empty list of (text, start, end, lead) items into a chunk dict."""
            nonlocal chunk_sequence
            chunk_sequence += 1
            # Each item carries the separator that belongs in front of it:
            # a blank line before a new paragraph, a space between sentences
            # of the same paragraph.
            body = items[0][0] + ''.join(
                lead + piece for piece, _start, _end, lead in items[1:]
            )
            pages = self.get_pages_for_chunk(
                items[0][1] + front_matter_offset,
                items[-1][2] + front_matter_offset,
                page_map
            )
            return {
                'chunk_id': f"{document_id}_C{chunk_sequence:04d}",
                'document_id': document_id,
                'chunk_sequence': chunk_sequence,
                'chunk_text': body,
                'chunk_tokens': self.count_tokens(body),
                'page_start': pages['page_start'],
                'page_end': pages['page_end']
            }

        pending = []          # items accumulated for the chunk being built
        pending_tokens = 0    # sum of the per-item token counts in `pending`
        cursor = 0            # search position in the (front-matter-free) text

        for paragraph in self.split_into_paragraphs(text):
            # Locate the paragraph in the text to recover its character span.
            para_start = text.find(paragraph, cursor)
            if para_start == -1:
                para_start = cursor
            para_end = para_start + len(paragraph)
            cursor = para_end

            if self.count_tokens(paragraph) > self.max_tokens:
                # Oversized paragraph: work at sentence granularity.
                pieces = []
                offset_in_para = 0
                lead = '\n\n'
                for sentence in self.split_into_sentences(paragraph):
                    found = paragraph.find(sentence, offset_in_para)
                    if found == -1:
                        # Should not happen, but never let -1 leak into positions.
                        found = offset_in_para
                    sent_start = para_start + found
                    sent_end = sent_start + len(sentence)
                    pieces.append((sentence, sent_start, sent_end, lead))
                    offset_in_para = found + len(sentence)
                    lead = ' '
                overlap_items = 2
            else:
                pieces = [(paragraph, para_start, para_end, '\n\n')]
                overlap_items = 1

            for piece in pieces:
                piece_tokens = self.count_tokens(piece[0])

                # Emit only when there is something to emit and it is large
                # enough; otherwise keep accumulating so the buffered text is
                # not thrown away.
                if (pending
                        and pending_tokens >= self.min_tokens
                        and pending_tokens + piece_tokens > self.target_tokens):
                    yield build_chunk(pending)
                    # Carry the tail of the emitted chunk over as overlap.
                    pending = pending[-overlap_items:]
                    pending_tokens = sum(self.count_tokens(item[0]) for item in pending)

                pending.append(piece)
                pending_tokens += piece_tokens

        # Yield final chunk
        if pending and pending_tokens >= self.min_tokens:
            yield build_chunk(pending)


class AdaptiveChunker:
    """
    Hierarchical chunking that creates parent-child relationships.

    This chunker creates two levels of chunks:
    - Parent chunks (~2000 tokens): Full context for LLM generation
    - Child chunks (~200 tokens): Precise search targets

    When a child chunk is found via search, the parent provides
    surrounding context for better understanding and RAG generation.
    """

    def __init__(
        self,
        parent_tokens: int = 2000,
        child_tokens: int = 200,
        child_overlap: int = 50
    ):
        """
        Set the size targets for both chunk levels and prepare the tokenizer.

        Args:
            parent_tokens: Target size for parent chunks
            child_tokens: Target size for child chunks
            child_overlap: Overlap between child chunks
        """
        self.parent_tokens, self.child_tokens, self.child_overlap = (
            parent_tokens, child_tokens, child_overlap
        )

        # Encoder is None when tiktoken is not installed; count_tokens() then
        # uses a character-based estimate.
        self.encoder = (
            tiktoken.get_encoding(CHUNK_CONFIG['encoding']) if HAS_TIKTOKEN else None
        )

    def count_tokens(self, text: str) -> int:
        """Return the token count of *text* (about len/4 when no encoder is set)."""
        encoder = self.encoder
        return len(encoder.encode(text)) if encoder else len(text) // 4

    def detect_sections(self, text: str) -> List[Dict[str, Any]]:
        """
        Locate heading lines in *text* and describe the sections they open.

        A line counts as a heading when, after stripping whitespace, it is a
        markdown header (1-4 '#'), a "Chapter"/"Part"/"Section" line, or a
        numbered title such as "3. Results". Each section runs from the start
        of its heading line to the start of the next heading (or to the end
        of the text for the last one).

        Returns list of {title, start, level, line, end} dictionaries, where
        title is the stripped heading truncated to 100 characters and line is
        the zero-based line index.
        """
        # (compiled pattern, level); a level of None means "number of '#'".
        heading_rules = [
            (re.compile(r'^(#{1,4})\s+(.+)$'), None),  # Markdown headers
            (re.compile(r'^(Chapter|CHAPTER)\s+(\d+|[IVX]+)[:\.]?\s*(.*)$'), 1),
            (re.compile(r'^(Part|PART)\s+(\d+|[IVX]+)[:\.]?\s*(.*)$'), 1),
            (re.compile(r'^(Section|SECTION)\s+(\d+)[:\.]?\s*(.*)$'), 2),
            (re.compile(r'^(\d+\.)\s+([A-Z].+)$'), 2),  # Numbered sections
        ]

        found = []
        offset = 0  # character offset of the current line within text
        for line_no, raw_line in enumerate(text.split('\n')):
            stripped = raw_line.strip()
            for matcher, fixed_level in heading_rules:
                hit = matcher.match(stripped)
                if hit is None:
                    continue
                level = len(hit.group(1)) if fixed_level is None else fixed_level
                found.append({
                    'title': stripped[:100],
                    'start': offset,
                    'level': level,
                    'line': line_no
                })
                break  # first matching rule wins
            offset += len(raw_line) + 1  # +1 for the newline removed by split

        # Every section ends where the next begins; the last ends with the text.
        section_ends = [entry['start'] for entry in found[1:]] + [len(text)]
        for entry, end in zip(found, section_ends):
            entry['end'] = end

        return found

    def split_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences, handling abbreviations.

        Dots between capital initials and after common titles/abbreviations
        ("Dr.", "etc.", "e.g.") are protected with a placeholder so they do
        not count as sentence ends; the text is then split on whitespace that
        follows '.', '!' or '?'.

        Args:
            text: Text to split.

        Returns:
            Non-empty sentences with surrounding whitespace removed.
        """
        # Dots between capital letters (initials) do not end a sentence.
        text = re.sub(r'(?<=[A-Z])\.(?=[A-Z])', '@@DOT@@', text)
        # \b anchors the abbreviation at a word start; without it ordinary
        # words ending in "ms.", "dr.", "sr." etc. ("systems.") would be
        # protected too and their sentence break would be lost.
        text = re.sub(r'\b(Mr|Mrs|Ms|Dr|Prof|Jr|Sr|vs|etc|i\.e|e\.g)\.',
                      r'\1@@DOT@@', text, flags=re.IGNORECASE)

        # Split on sentence boundaries
        pieces = re.split(r'(?<=[.!?])\s+', text)

        # Restore dots and drop whitespace-only fragments
        return [piece.replace('@@DOT@@', '.').strip() for piece in pieces if piece.strip()]

    def chunk_hierarchical(
        self,
        text: str,
        document_id: str
    ) -> List[Dict[str, Any]]:
        """
        Create hierarchical parent-child chunks.

        Leading markdown front matter is removed, the text is divided into
        sections at detected headings (any text before the first heading
        forms an untitled section of its own), each section is packed into
        parent chunks, and every parent is split into child chunks. Chunks
        are returned in document order: each parent is immediately followed
        by its children.

        ``chunk_sequence`` is assigned here and is unique and strictly
        increasing across the returned list; a parent's ``chunk_id`` is
        rebuilt from that final sequence number so ids and sequence numbers
        agree.

        Returns list of chunk dictionaries with:
        - chunk_id, document_id, chunk_sequence
        - chunk_text, chunk_tokens
        - chunk_level: 'parent' or 'child'
        - parent_chunk_id: reference for child chunks
        - section_title: from detected sections
        """
        # Remove markdown front matter
        if text.startswith('---'):
            end_match = re.search(r'\n---\n', text[3:])
            if end_match:
                text = text[end_match.end() + 3:]

        chunks = []
        chunk_sequence = 0

        # Detect sections
        sections = self.detect_sections(text)

        if not sections:
            # No sections detected - treat entire text as one section
            sections = [{'title': None, 'start': 0, 'end': len(text), 'level': 1}]
        elif sections[0]['start'] > 0:
            # Text that precedes the first heading belongs to no detected
            # section; give it an untitled one so it is not silently dropped.
            sections.insert(0, {
                'title': None,
                'start': 0,
                'end': sections[0]['start'],
                'level': 1
            })

        for section in sections:
            section_text = text[section['start']:section['end']].strip()
            if not section_text:
                continue

            # Create parent chunks for this section
            parent_chunks = self._create_parent_chunks(
                section_text, document_id, chunk_sequence, section.get('title')
            )

            for parent in parent_chunks:
                # The helper numbers all parents of a section consecutively,
                # before any children exist. Renumber here so a parent's
                # sequence never collides with a child emitted earlier.
                chunk_sequence += 1
                parent['chunk_sequence'] = chunk_sequence
                parent['chunk_id'] = f"{document_id}_P{chunk_sequence:04d}"
                chunks.append(parent)

                # Create child chunks within this parent
                child_chunks = self._create_child_chunks(
                    parent['chunk_text'],
                    document_id,
                    parent['chunk_id'],
                    chunk_sequence
                )

                for child in child_chunks:
                    chunk_sequence += 1
                    child['chunk_sequence'] = chunk_sequence
                    chunks.append(child)

        return chunks

    def _create_parent_chunks(
        self,
        text: str,
        doc_id: str,
        start_sequence: int,
        section_title: str = None
    ) -> List[Dict[str, Any]]:
        """
        Pack the paragraphs of *text* into parent-level chunks.

        Paragraphs (separated by blank lines) are accumulated until adding the
        next one would exceed ``self.parent_tokens``; the accumulated group is
        then stored as one parent chunk. A single paragraph is never split, so
        a parent may exceed the target. Sequence numbers continue from
        ``start_sequence + 1``.
        """
        parents = []

        def make_parent(pieces):
            """Build the parent dict for one group of paragraphs."""
            body = '\n\n'.join(pieces)
            number = start_sequence + len(parents) + 1
            return {
                'chunk_id': f"{doc_id}_P{number:04d}",
                'document_id': doc_id,
                'chunk_sequence': number,
                'chunk_text': body,
                'chunk_tokens': self.count_tokens(body),
                'chunk_level': 'parent',
                'parent_chunk_id': None,
                'section_title': section_title
            }

        group = []
        group_tokens = 0
        for block in re.split(r'\n\s*\n', text):
            para = block.strip()
            if not para:
                continue
            para_tokens = self.count_tokens(para)

            if group and group_tokens + para_tokens > self.parent_tokens:
                # Close the current group and start a new one with this paragraph.
                parents.append(make_parent(group))
                group = [para]
                group_tokens = para_tokens
            else:
                group.append(para)
                group_tokens += para_tokens

        # Whatever is left forms the last parent.
        if group:
            parents.append(make_parent(group))

        return parents

    def _create_child_chunks(
        self,
        text: str,
        doc_id: str,
        parent_id: str,
        start_sequence: int
    ) -> List[Dict[str, Any]]:
        """
        Split a parent's text into sentence-based child chunks.

        Sentences are accumulated until the next one would push the running
        total past ``self.child_tokens``; the group is stored as a child and
        its last two sentences (or the only one) are carried into the next
        group as overlap. The overlap is counted in sentences here;
        ``self.child_overlap`` is not consulted. Sequence numbers continue
        from ``start_sequence + 1``.
        """
        children = []

        def make_child(pieces):
            """Build the child dict for one group of sentences."""
            body = ' '.join(pieces)
            number = start_sequence + len(children) + 1
            return {
                'chunk_id': f"{doc_id}_C{number:04d}",
                'document_id': doc_id,
                'chunk_sequence': number,
                'chunk_text': body,
                'chunk_tokens': self.count_tokens(body),
                'chunk_level': 'child',
                'parent_chunk_id': parent_id,
                'section_title': None
            }

        window = []
        window_tokens = 0
        for sentence in self.split_into_sentences(text):
            sentence_tokens = self.count_tokens(sentence)

            if window and window_tokens + sentence_tokens > self.child_tokens:
                children.append(make_child(window))
                # Overlap: keep last 1-2 sentences, then add the new one.
                window = window[-2:] + [sentence]
                window_tokens = sum(self.count_tokens(s) for s in window)
            else:
                window.append(sentence)
                window_tokens += sentence_tokens

        # Remaining sentences form the last child.
        if window:
            children.append(make_child(window))

        return children


class SemanticChunker:
    """
    Semantic Chunking using embedding-based boundary detection (Optimization #3).

    Instead of breaking text at arbitrary token counts, this chunker:
    1. Splits text into sentences
    2. Computes embeddings for sentence groups
    3. Detects semantic shifts by measuring cosine similarity drop
    4. Creates chunk boundaries at topic transition points

    This preserves topical coherence within chunks, improving both
    search precision and LLM context quality.
    """

    def __init__(
        self,
        min_tokens: int = 100,
        max_tokens: int = 1000,
        target_tokens: int = 500,
        similarity_threshold: float = 0.5,
        window_size: int = 3
    ):
        """
        Store the chunking parameters and prepare the tokenizer.

        Args:
            min_tokens: Minimum tokens per chunk
            max_tokens: Maximum tokens per chunk
            target_tokens: Ideal chunk size
            similarity_threshold: Cosine similarity threshold for boundary detection
                                  Lower = more sensitive to topic changes
            window_size: Number of sentences to group for embedding comparison
        """
        self.min_tokens, self.max_tokens, self.target_tokens = (
            min_tokens, max_tokens, target_tokens
        )
        self.similarity_threshold = similarity_threshold
        self.window_size = window_size

        # Encoder is None when tiktoken is not installed.
        self.encoder = (
            tiktoken.get_encoding(CHUNK_CONFIG['encoding']) if HAS_TIKTOKEN else None
        )

        # The embedding model is loaded on first use. _embedding_available is
        # None until a load has been attempted, then True or False.
        self._embedding_model = None
        self._embedding_available = None

    def _get_embedding_model(self):
        """Return the sentence embedding model, loading it on first use.

        Returns None when a previous load attempt failed or when loading
        fails now; a failure is remembered so it is not retried.
        """
        # A failed attempt is final.
        if self._embedding_available is False:
            return None

        # Already loaded: hand back the cached instance.
        if self._embedding_model is not None:
            return self._embedding_model

        try:
            from sentence_transformers import SentenceTransformer
            # Use a lightweight model for fast local embeddings
            self._embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            self._embedding_available = True
            logger.info("Loaded sentence-transformers model for semantic chunking")
        except ImportError:
            logger.warning("sentence-transformers not installed. Falling back to standard chunking.")
            logger.warning("Install with: pip install sentence-transformers")
            self._embedding_available = False
            return None
        except Exception as e:
            logger.warning(f"Failed to load embedding model: {e}")
            self._embedding_available = False
            return None

        return self._embedding_model

    def count_tokens(self, text: str) -> int:
        """Return the token count of *text*, estimated as len/4 without an encoder."""
        if not self.encoder:
            return len(text) // 4
        encoded = self.encoder.encode(text)
        return len(encoded)

    def split_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences, handling abbreviations.

        Dots between capital initials and after common titles/abbreviations
        ("Mr.", "vs.", "i.e.") are swapped for a placeholder so they are not
        taken as sentence ends; the text is then split on whitespace that
        follows '.', '!' or '?', and the placeholder is restored.

        Args:
            text: Text to split.

        Returns:
            Non-empty sentences with surrounding whitespace removed.
        """
        # Dots between capital letters (initials) do not end a sentence.
        text = re.sub(r'(?<=[A-Z])\.(?=[A-Z])', '@@DOT@@', text)
        # The abbreviation must start at a word boundary; otherwise the tail
        # of an ordinary word ("systems." ends in "ms.") matches and the
        # sentence break after that word disappears.
        text = re.sub(r'\b(Mr|Mrs|Ms|Dr|Prof|Jr|Sr|vs|etc|i\.e|e\.g)\.',
                      r'\1@@DOT@@', text, flags=re.IGNORECASE)

        # Split on sentence boundaries
        pieces = re.split(r'(?<=[.!?])\s+', text)

        # Restore dots and drop whitespace-only fragments
        return [piece.replace('@@DOT@@', '.').strip() for piece in pieces if piece.strip()]

    def compute_similarity(self, embedding1, embedding2) -> float:
        """Return the cosine similarity of two embeddings (0.0 if either has zero norm)."""
        import numpy as np
        numerator = np.dot(embedding1, embedding2)
        lengths = (np.linalg.norm(embedding1), np.linalg.norm(embedding2))
        # A zero-length vector has no direction; report no similarity rather
        # than dividing by zero.
        if any(length == 0 for length in lengths):
            return 0.0
        return numerator / (lengths[0] * lengths[1])

    def find_semantic_boundaries(self, sentences: List[str]) -> List[int]:
        """
        Locate sentence indices at which the topic appears to change.

        Sliding windows of ``window_size`` consecutive sentences are embedded
        and each window is compared with the next one. A boundary is reported
        where the similarity is lower than both ``similarity_threshold`` and
        (mean - one standard deviation) of all similarities.

        Returns list of sentence indices that should start new chunks; empty
        when no model is available, there are fewer than 2 * window_size
        sentences, or embedding fails.
        """
        model = self._get_embedding_model()
        if model is None or len(sentences) < self.window_size * 2:
            return []

        import numpy as np

        width = self.window_size

        # One text per sliding window of `width` sentences.
        windows = [
            ' '.join(sentences[first:first + width])
            for first in range(len(sentences) - width + 1)
        ]
        if len(windows) < 2:
            return []

        # Embed every window in a single call.
        try:
            embeddings = model.encode(windows, show_progress_bar=False)
        except Exception as e:
            logger.warning(f"Embedding computation failed: {e}")
            return []

        # Similarity of each window with its successor.
        similarities = [
            self.compute_similarity(embeddings[k], embeddings[k + 1])
            for k in range(len(embeddings) - 1)
        ]
        if not similarities:
            return []

        # Taking the minimum means a drop must undercut the fixed threshold
        # AND the data-derived one to count.
        cutoff = min(
            self.similarity_threshold,
            np.mean(similarities) - np.std(similarities)
        )

        # The boundary sits at the sentence right after the earlier window.
        return [
            k + width
            for k, score in enumerate(similarities)
            if score < cutoff and k + width < len(sentences)
        ]

    def chunk_semantic(
        self,
        text: str,
        document_id: str,
        page_map: list = None
    ) -> List[Dict[str, Any]]:
        """
        Create chunks using semantic boundary detection.

        Leading markdown front matter is removed, the text is split into
        sentences, and the sentences are grouped into segments at the
        detected semantic boundaries. A segment smaller than ``min_tokens``
        is merged into the following segment (the last segment is emitted
        regardless of size). A segment larger than ``max_tokens`` is divided
        into roughly equal sentence groups sized to stay at or below
        ``max_tokens`` where sentence lengths allow.

        Falls back to standard chunking (TextChunker) if the embedding model
        is unavailable. Semantic chunks do not carry page numbers; page_map
        is only used by the standard fallback.

        Args:
            text: Full document text
            document_id: Document identifier
            page_map: Optional page position mapping

        Returns:
            List of chunk dictionaries
        """
        # Kept for the fallback: TextChunker strips front matter itself and
        # needs the untouched text so page_map offsets stay aligned.
        original_text = text

        # Remove markdown front matter
        if text.startswith('---'):
            end_match = re.search(r'\n---\n', text[3:])
            if end_match:
                text = text[end_match.end() + 3:]

        sentences = self.split_into_sentences(text)

        if not sentences:
            return []

        # Without embeddings no boundary can be detected, so the whole
        # document would be one segment; use the standard chunker instead.
        if self._get_embedding_model() is None:
            logger.info(f"Embedding model unavailable, using standard chunking for {document_id}")
            return list(TextChunker().chunk_text(original_text, document_id, page_map))

        # Find semantic boundaries (indices that start a new segment)
        boundaries = self.find_semantic_boundaries(sentences)
        cut_points = sorted({b for b in boundaries if 0 < b < len(sentences)})
        edges = [0] + cut_points + [len(sentences)]

        chunks = []

        def emit(group):
            """Append one chunk built from a list of sentences (skips empty text)."""
            body = ' '.join(group).strip()
            if not body:
                return
            sequence = len(chunks) + 1
            chunks.append({
                'chunk_id': f"{document_id}_C{sequence:04d}",
                'document_id': document_id,
                'chunk_sequence': sequence,
                'chunk_text': body,
                'chunk_tokens': self.count_tokens(body),
                'chunk_method': 'semantic',
                'page_start': None,
                'page_end': None
            })

        pending = []  # sentences carried over from segments that were too small
        last_segment = len(edges) - 2

        for segment_index in range(len(edges) - 1):
            pending.extend(sentences[edges[segment_index]:edges[segment_index + 1]])
            pending_tokens = self.count_tokens(' '.join(pending))

            # Too small: keep the sentences and merge them with the next
            # segment instead of discarding them.
            if pending_tokens < self.min_tokens and segment_index < last_segment:
                continue

            if pending_tokens > self.max_tokens:
                # Too large: aim for the smallest number of equal parts that
                # each fit within max_tokens, breaking only between sentences.
                parts_needed = max(2, -(-pending_tokens // max(1, self.max_tokens)))
                budget = pending_tokens / parts_needed
                group = []
                group_tokens = 0
                for sentence in pending:
                    sentence_tokens = self.count_tokens(sentence)
                    if group and group_tokens + sentence_tokens > budget:
                        emit(group)
                        group = []
                        group_tokens = 0
                    group.append(sentence)
                    group_tokens += sentence_tokens
                emit(group)
            else:
                emit(pending)

            pending = []

        # If semantic chunking produced no results, fall back to standard
        if not chunks:
            logger.warning(f"Semantic chunking produced no chunks for {document_id}, falling back to standard")
            standard_chunker = TextChunker()
            return list(standard_chunker.chunk_text(original_text, document_id, page_map))

        return chunks


class DocumentChunker:
    """Orchestrates document chunking process.

    Picks a chunking strategy at construction time, then for each document:
    loads its text from disk, generates chunks, stores them through
    ``batch_insert_chunks``, refreshes the full-text search vectors and
    updates the document's processing status. Running totals are kept in
    ``self.stats``.
    """

    def __init__(self, rechunk: bool = False, hierarchical: bool = False, semantic: bool = False):
        """
        Initialize the document chunker.

        Args:
            rechunk: Force rechunk all documents
            hierarchical: Use hierarchical parent-child chunking
            semantic: Use semantic boundary detection for chunking
                (wins over ``hierarchical`` if both are set)
        """
        self.rechunk = rechunk
        self.hierarchical = hierarchical
        self.semantic = semantic
        # ``mode`` is the single source of truth for which chunker is active;
        # the raw flags above are kept only for backward compatibility.
        self.mode = 'standard'

        if semantic:
            self.chunker = SemanticChunker()
            self.mode = 'semantic'
            logger.info("Using semantic chunking mode (embedding-based boundaries)")
        elif hierarchical:
            self.chunker = AdaptiveChunker()
            self.mode = 'hierarchical'
            logger.info("Using hierarchical chunking mode (parent-child)")
        else:
            self.chunker = TextChunker()
            logger.info("Using standard chunking mode")

        self.stats = {
            'documents_processed': 0,
            'chunks_created': 0,
            'parent_chunks': 0,
            'child_chunks': 0,
            'errors': 0
        }

    def get_documents_to_chunk(self, document_id: str = None) -> List[Dict]:
        """
        Get list of documents that need chunking.

        Args:
            document_id: If given, select exactly this document regardless of
                its processing status.

        Returns:
            Whatever ``execute_query`` returns for the selection: the single
            requested document, or all documents with status 'ingested'
            (plus 'completed' when rechunking).
        """
        if document_id:
            return execute_query(
                "SELECT * FROM documents WHERE document_id = %s",
                (document_id,),
                fetch='all'
            )

        if self.rechunk:
            return execute_query(
                "SELECT * FROM documents WHERE processing_status IN ('ingested', 'completed')",
                fetch='all'
            )

        return execute_query(
            "SELECT * FROM documents WHERE processing_status = 'ingested'",
            fetch='all'
        )

    def load_document_content(self, document: Dict) -> str:
        """
        Load content from document's markdown file.

        Args:
            document: Document row; must contain 'document_id' and 'file_path'.

        Returns:
            The file's full text, decoded as UTF-8.

        Raises:
            ValueError: If the document has no file path.
            FileNotFoundError: If the file path does not exist.
        """
        file_path = document.get('file_path')
        if not file_path:
            raise ValueError(f"No file path for document {document['document_id']}")

        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        return path.read_text(encoding='utf-8')

    def delete_existing_chunks(self, document_id: str) -> int:
        """
        Delete existing chunks for a document.

        Returns:
            The cursor's row count for the DELETE statement.
        """
        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM chunks WHERE document_id = %s",
                    (document_id,)
                )
                return cur.rowcount

    def chunk_document(self, document: Dict) -> int:
        """
        Chunk a single document and store in database.

        The document status moves to 'chunking' while work is in progress and
        ends at 'chunked' on success or 'failed' on any error (or when no
        chunks are produced).

        Args:
            document: Document row; must contain 'document_id' and 'file_path'.

        Returns:
            Number of chunks created (0 if the chunker produced nothing).

        Raises:
            Exception: Any error from loading, chunking or storing is logged,
                recorded on the document as 'failed', and re-raised.
        """
        document_id = document['document_id']
        logger.info(f"Chunking document: {document_id}")

        try:
            # Update status
            update_document_status(document_id, 'chunking')

            # Load content
            content = self.load_document_content(document)

            # Generate the new chunks BEFORE deleting stored ones: if the
            # chunker raises, the previously stored chunks (and any
            # embeddings attached to them) must survive.
            if self.mode == 'semantic':
                chunks = self.chunker.chunk_semantic(content, document_id)
            elif self.mode == 'hierarchical':
                chunks = self.chunker.chunk_hierarchical(content, document_id)
            else:
                chunks = list(self.chunker.chunk_text(content, document_id))

            # Delete existing chunks if rechunking
            if self.rechunk:
                deleted = self.delete_existing_chunks(document_id)
                if deleted:
                    logger.info(f"Deleted {deleted} existing chunks")

            if not chunks:
                logger.warning(f"No chunks generated for {document_id}")
                update_document_status(document_id, 'failed', 'No chunks generated')
                return 0

            # Count chunk types for stats (only hierarchical chunks carry
            # a 'chunk_level'; other modes yield zero for both)
            parent_count = sum(1 for c in chunks if c.get('chunk_level') == 'parent')
            child_count = sum(1 for c in chunks if c.get('chunk_level') == 'child')

            # Batch insert chunks
            chunk_count = batch_insert_chunks(chunks)

            # Update full-text search vectors
            with get_db_connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        UPDATE chunks
                        SET chunk_text_tsv = to_tsvector('english', chunk_text)
                        WHERE document_id = %s
                    """, (document_id,))

            # Update document status
            update_document_status(document_id, 'chunked')

            # Keyed on the active mode so parent/child stats are only
            # reported when the hierarchical chunker actually ran.
            if self.mode == 'hierarchical':
                logger.info(f"Created {chunk_count} chunks ({parent_count} parents, {child_count} children) for {document_id}")
                self.stats['parent_chunks'] += parent_count
                self.stats['child_chunks'] += child_count
            else:
                logger.info(f"Created {chunk_count} chunks for {document_id}")

            return chunk_count

        except Exception as e:
            logger.error(f"Error chunking {document_id}: {e}")
            update_document_status(document_id, 'failed', str(e))
            raise

    def process_all(self, document_id: str = None) -> Dict[str, int]:
        """
        Process all documents that need chunking.

        A failure in one document is logged and counted in ``stats['errors']``
        without stopping the remaining documents.

        Args:
            document_id: Restrict processing to this document, if given.

        Returns:
            The cumulative ``self.stats`` dictionary.
        """
        documents = self.get_documents_to_chunk(document_id)

        if not documents:
            logger.info("No documents to chunk")
            return self.stats

        logger.info(f"Found {len(documents)} documents to chunk")

        for doc in documents:
            try:
                chunks_created = self.chunk_document(doc)
                self.stats['documents_processed'] += 1
                self.stats['chunks_created'] += chunks_created
            except Exception as e:
                # chunk_document has already logged and marked the document
                # as failed; just record it and move on to the next one.
                logger.error(f"Failed to chunk {doc['document_id']}: {e}")
                self.stats['errors'] += 1

        return self.stats


def main():
    """
    Command-line entry point for the chunking pipeline.

    Parses arguments, rejects conflicting modes, warns (and asks for
    confirmation) before a re-chunk that would discard embedded chunks,
    runs the chunker and prints a configuration banner and a summary.
    """
    parser = argparse.ArgumentParser(
        description='Chunk documents for the Research Development Framework'
    )
    parser.add_argument(
        '--document',
        type=str,
        help='Process a specific document by ID'
    )
    parser.add_argument(
        '--rechunk',
        action='store_true',
        help='Force rechunk all documents (deletes existing chunks)'
    )
    parser.add_argument(
        '--hierarchical',
        action='store_true',
        help='Use hierarchical parent-child chunking (better for RAG)'
    )
    parser.add_argument(
        '--semantic',
        action='store_true',
        help='Use semantic boundary detection for chunking (requires sentence-transformers)'
    )
    parser.add_argument(
        '--similarity-threshold',
        type=float,
        default=0.5,
        help='Similarity threshold for semantic chunking boundary detection (default: 0.5)'
    )
    parser.add_argument(
        '--parent-tokens',
        type=int,
        default=2000,
        help='Target tokens for parent chunks (default: 2000)'
    )
    parser.add_argument(
        '--child-tokens',
        type=int,
        default=200,
        help='Target tokens for child chunks (default: 200)'
    )
    parser.add_argument(
        '--yes', '-y',
        action='store_true',
        help='Skip confirmation prompts (for automation)'
    )

    args = parser.parse_args()

    # Validate mutually exclusive modes first, so an invalid invocation is
    # rejected before any database access or confirmation prompt.
    if args.hierarchical and args.semantic:
        print("ERROR: Cannot use both --hierarchical and --semantic modes")
        return

    # Warning for rechunk operation - will invalidate embeddings
    if args.rechunk:
        from db_utils import get_db_connection

        # Only count chunks this run can delete: with --document, just that
        # document's chunks are removed, so scope the statistics to it.
        stats_sql = "SELECT COUNT(*), SUM(chunk_tokens) FROM chunks WHERE embedding IS NOT NULL"
        stats_params = None
        if args.document:
            stats_sql += " AND document_id = %s"
            stats_params = (args.document,)

        with get_db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(stats_sql, stats_params)
                row = cur.fetchone()
                embedded_count = row[0]
                # SUM() is NULL when no rows match; float() also guards
                # against a Decimal result being multiplied by a float.
                total_tokens = float(row[1] or 0)

        # The prompt runs after the connection is released so the database
        # is not held open while waiting for user input.
        if embedded_count > 0:
            # Calculate approximate re-embedding cost
            estimated_cost = (total_tokens / 1_000_000) * 0.02  # $0.02/1M tokens

            print("\n" + "!" * 60)
            print("WARNING: Re-chunking will INVALIDATE existing embeddings!")
            print("!" * 60)
            print(f"Embedded chunks to be deleted: {embedded_count:,}")
            print(f"Estimated re-embedding cost:   ${estimated_cost:.4f}")
            print("")
            print("You will need to run generate_embeddings.py again after")
            print("re-chunking to restore semantic search capability.")
            print("!" * 60)

            if not args.yes:
                try:
                    response = input("\nProceed with re-chunking? [y/N]: ")
                except EOFError:
                    # No interactive stdin and no --yes: take the safe default.
                    response = ''
                if response.strip().lower() not in ('y', 'yes'):
                    print("Aborted.")
                    return
            else:
                print("\n--yes flag provided, proceeding...")

    # NOTE(review): --similarity-threshold, --parent-tokens and --child-tokens
    # are only echoed in the banner below; DocumentChunker takes no such
    # parameters, so these values do not currently reach the chunkers.
    processor = DocumentChunker(
        rechunk=args.rechunk,
        hierarchical=args.hierarchical,
        semantic=args.semantic
    )

    # Print configuration
    print("\n" + "=" * 50)
    print("CHUNKING CONFIGURATION")
    print("=" * 50)
    if args.semantic:
        mode = "Semantic (embedding-based boundaries)"
    elif args.hierarchical:
        mode = "Hierarchical (parent-child)"
    else:
        mode = "Standard"
    print(f"Mode:          {mode}")
    print(f"Rechunk:       {'Yes' if args.rechunk else 'No'}")
    if args.hierarchical:
        print(f"Parent tokens: {args.parent_tokens}")
        print(f"Child tokens:  {args.child_tokens}")
    if args.semantic:
        print(f"Similarity threshold: {args.similarity_threshold}")
    print("=" * 50 + "\n")

    stats = processor.process_all(document_id=args.document)

    # Print summary
    print("\n" + "=" * 50)
    print("CHUNKING SUMMARY")
    print("=" * 50)
    print(f"Documents Processed: {stats['documents_processed']}")
    print(f"Chunks Created:      {stats['chunks_created']}")
    if args.hierarchical:
        print(f"  - Parent Chunks:   {stats['parent_chunks']}")
        print(f"  - Child Chunks:    {stats['child_chunks']}")
    print(f"Errors:              {stats['errors']}")
    print("=" * 50)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
