#!/usr/bin/env python3
"""
Metadata Editor for Research Development Framework

Edit document metadata without re-ingesting or losing embeddings.
Solves the "linear pipeline trap" by allowing in-place metadata corrections.

Usage:
    # Show current metadata
    python edit_metadata.py DOC_ID --show

    # Edit single field
    python edit_metadata.py DOC_ID --set title "New Title"

    # Edit multiple fields
    python edit_metadata.py DOC_ID --set title "New Title" --set author "Author Name"

    # Interactive edit mode
    python edit_metadata.py DOC_ID --interactive

    # Batch edit from JSON file
    python edit_metadata.py --batch edits.json

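    # Preview changes without applying them
    python edit_metadata.py DOC_ID --set title "New Title" --dry-run

    # List documents needing review
    python edit_metadata.py --list-review

    # List editable and protected fields
    python edit_metadata.py --list-fields

    # Generate missing citation keys (add --overwrite-keys to regenerate all)
    python edit_metadata.py --generate-keys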
"""

import sys
import os
import json
import argparse
from datetime import datetime
from typing import Dict, Any, List, Optional

# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.chdir(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from pipeline.db_utils import get_dict_cursor, execute_query


# Editable fields - fields that can be modified in place without re-ingesting.
# (SEARCHABLE_METADATA_FIELDS lists the subset that may still affect search.)
#
# Each spec is consumed by validate_field():
#   type        - 'str', 'int' or 'bool'; the raw value is coerced to this type
#   max_length  - (str only) maximum length; None or absent means unlimited
#   min / max   - (int only) inclusive bounds
#   description - help text shown in interactive mode and by --list-fields
# 'author' is not a documents column: update_document() resolves it to an
# author_id through the authors table. Dict order is the display order used
# by interactive mode and --list-fields.
EDITABLE_FIELDS = {
    # Core metadata
    'title': {'type': 'str', 'max_length': 500, 'description': 'Document title'},
    'subtitle': {'type': 'str', 'max_length': 500, 'description': 'Document subtitle'},
    'author': {'type': 'str', 'max_length': 255, 'description': 'Author name (updates authors table)'},
    'publication_year': {'type': 'int', 'min': 1, 'max': 2100, 'description': 'Year of publication'},
    'publisher': {'type': 'str', 'max_length': 255, 'description': 'Publisher name'},
    'language_code': {'type': 'str', 'max_length': 10, 'description': 'ISO language code (e.g., en, de, fr)'},
    'edition_type': {'type': 'str', 'max_length': 50, 'description': 'Edition type (original, translation, etc.)'},

    # Translation metadata
    'translator': {'type': 'str', 'max_length': 255, 'description': 'Translator name'},
    'original_title': {'type': 'str', 'max_length': 500, 'description': 'Original title (for translations)'},
    'original_language': {'type': 'str', 'max_length': 10, 'description': 'Original language code'},

    # Classification
    'primary_category': {'type': 'str', 'max_length': 100, 'description': 'Primary subject category'},
    'content_type': {'type': 'str', 'max_length': 50, 'description': 'Content type (book, article, lecture, etc.)'},
    'difficulty_level': {'type': 'str', 'max_length': 20, 'description': 'Difficulty (introductory, intermediate, advanced)'},

    # Citation data
    'bibtex_key': {'type': 'str', 'max_length': 100, 'description': 'BibTeX citation key'},
    'bibtex_type': {'type': 'str', 'max_length': 50, 'description': 'BibTeX entry type (book, article, etc.)'},

    # Flags
    'needs_review': {'type': 'bool', 'description': 'Flag for manual review needed'},
    'has_index': {'type': 'bool', 'description': 'Document has an index'},
    'has_bibliography': {'type': 'bool', 'description': 'Document has a bibliography'},

    # Notes (max_length None = no length limit)
    'notes': {'type': 'str', 'max_length': None, 'description': 'Free-form notes about the document'},
}

# Fields that should NOT be edited (would break pipeline).
# update_document() rejects these with a dedicated "protected" error message;
# validate_field() would reject them anyway because they are not in
# EDITABLE_FIELDS. The list is also printed by --list-fields.
PROTECTED_FIELDS = [
    'document_id',      # Primary key
    'content_hash',     # Integrity check
    'source_file',      # Original filename
    'file_path',        # File location
    'word_count',       # Derived from content
    'page_count',       # Derived from content
    'chapter_count',    # Derived from content
    'image_count',      # Derived from content
    'processing_status', # Pipeline state
    'pipeline_version', # Pipeline state
    'created_at',       # Timestamp
    'updated_at',       # Auto-updated
    'author_id',        # Use 'author' field instead
]

# Fields that may affect search results when modified.
# Changes to these fields can cause metadata-embedding mismatch:
# check_searchable_fields_changed() matches change records against this list,
# and a match triggers the mismatch warning after a non-dry-run --set or
# --batch edit (unless --no-embed-warning is given).
SEARCHABLE_METADATA_FIELDS = [
    'title',            # Often embedded in chunk metadata
    'subtitle',         # May be embedded
    'author',           # May be embedded in context
    'primary_category', # Used in filtered searches
    'content_type',     # Used in filtered searches
]

# Warning message for metadata-embedding mismatch.
# Filled in with str.format(fields=..., doc_id=...) by main() after a --set
# edit; any literal '{' or '}' added to this text must be doubled ('{{', '}}').
EMBEDDING_MISMATCH_WARNING = """
⚠️  METADATA-EMBEDDING MISMATCH WARNING
═══════════════════════════════════════════════════════════════════

You've updated field(s) that may affect search results:
  {fields}

These fields are stored in:
  1. Document metadata (PostgreSQL) ✓ Updated
  2. Chunk embeddings (vector DB)   ✗ NOT Updated

This can cause search inconsistencies where:
  - Metadata filters may not match embedded content
  - Hybrid search may return unexpected results
  - BM25 search may find old values in chunk text

RECOMMENDED ACTIONS:
  1. For minor corrections: No action needed
  2. For significant changes: Re-embed the document:
     python embed_chunks.py --document {doc_id} --force

  To suppress this warning, use: --no-embed-warning
═══════════════════════════════════════════════════════════════════
"""


def generate_citation_key(author: str, title: str, year: int) -> str:
    """
    Generate a BibTeX-style citation key from metadata.

    Format: lastname_shorttitle_year
    Example: steiner_philosophy_1894

    Accented letters are folded to their ASCII base letter before
    non-letters are stripped (``Gödel`` -> ``godel``, ``Über`` -> ``uber``),
    so non-English names are not mangled. Missing parts fall back to
    ``unknown`` (author), ``untitled`` (title) and ``nd`` (year, "no date").

    Args:
        author: Author name as "First Last" or "Last, First"; may be None/empty
        title: Document title; may be None/empty. The first word that is not
            a common article/preposition is used, truncated to 15 letters.
        year: Publication year; may be None/0

    Returns:
        Citation key string
    """
    import re
    import unicodedata

    def _letters_only(text: str) -> str:
        # NFKD splits e.g. "ö" into "o" + a combining mark; the ASCII encode
        # then drops the mark, keeping the base letter instead of losing it.
        folded = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('ascii')
        return re.sub(r'[^a-zA-Z]', '', folded).lower()

    # Extract last name (handle "First Last" and "Last, First")
    lastname = ''
    if author:
        author = author.strip()
        if ',' in author:
            lastname = author.split(',')[0].strip()
        else:
            parts = author.split()
            lastname = parts[-1] if parts else ''
    # Fall back when nothing usable is left (e.g. ", First" or a non-Latin script)
    lastname = _letters_only(lastname) or 'unknown'

    # Get first significant word from title (skip articles)
    skip_words = {'the', 'a', 'an', 'of', 'and', 'in', 'on', 'to', 'for'}
    title_word = 'untitled'
    for word in (title or '').lower().split():
        clean = _letters_only(word)
        if clean and clean not in skip_words:
            title_word = clean[:15]  # Limit length
            break

    # Year
    year_str = str(year) if year else 'nd'

    return f"{lastname}_{title_word}_{year_str}"


def generate_citation_keys_batch(overwrite: bool = False) -> Dict[str, Any]:
    """
    Generate citation keys (``bibtex_key``) for documents in bulk.

    Args:
        overwrite: If True, regenerate keys for every document; otherwise
            only documents whose key is NULL or empty are processed.

    Returns:
        Dict with stats:
            generated - documents whose key was written
            skipped   - documents whose regenerated key equals the stored key
                        (only possible with overwrite=True); no UPDATE is
                        issued for them
            errors    - documents whose key could not be generated or
                        written; each failure is reported on stderr
    """
    stats = {'generated': 0, 'skipped': 0, 'errors': 0}

    # bibtex_key is selected so unchanged keys can be skipped in overwrite mode
    if overwrite:
        query = """
            SELECT d.document_id, d.title, a.name as author, d.publication_year, d.bibtex_key
            FROM documents d
            LEFT JOIN authors a ON d.author_id = a.author_id
        """
    else:
        query = """
            SELECT d.document_id, d.title, a.name as author, d.publication_year, d.bibtex_key
            FROM documents d
            LEFT JOIN authors a ON d.author_id = a.author_id
            WHERE d.bibtex_key IS NULL OR d.bibtex_key = ''
        """

    docs = execute_query(query, fetch='all')

    if not docs:
        return stats

    # NOTE(review): keys are not de-duplicated - two documents with the same
    # author surname, first significant title word and year get the same key.
    # NOTE(review): if get_dict_cursor() runs all UPDATEs in one transaction
    # (not visible here), a failed UPDATE on PostgreSQL aborts it and later
    # UPDATEs in this loop fail too - confirm before trusting per-row counts.
    with get_dict_cursor() as (conn, cur):
        for doc in docs:
            doc_id = doc['document_id']
            try:
                key = generate_citation_key(
                    doc.get('author'),
                    doc.get('title'),
                    doc.get('publication_year')
                )
                if key == doc.get('bibtex_key'):
                    # Regenerated key is identical; avoid a pointless write
                    stats['skipped'] += 1
                    continue

                cur.execute(
                    "UPDATE documents SET bibtex_key = %s, updated_at = CURRENT_TIMESTAMP WHERE document_id = %s",
                    (key, doc_id)
                )
                stats['generated'] += 1

            except Exception as exc:
                # Best effort across the batch, but never fail silently
                stats['errors'] += 1
                print(f"  Error generating citation key for {doc_id}: {exc}", file=sys.stderr)

    return stats


def get_document(doc_id: str) -> Optional[Dict]:
    """
    Load one document row together with its author's name.

    The row holds every ``documents`` column plus ``author_name`` taken from a
    LEFT JOIN on ``authors`` (so it is None when the document has no author).
    Callers treat a falsy result as "document not found".
    """
    sql = """
        SELECT d.*, a.name as author_name
        FROM documents d
        LEFT JOIN authors a ON d.author_id = a.author_id
        WHERE d.document_id = %s
    """
    params = (doc_id,)
    row = execute_query(sql, params, fetch='one')
    return row


def get_or_create_author(author_name: str) -> int:
    """
    Return the author_id for ``author_name``, inserting a new author if needed.

    The lookup is an exact match on ``authors.name``. When no row matches, a
    new author row is inserted and its generated id is returned.
    """
    # NOTE(review): check-then-insert is not atomic; concurrent callers could
    # both insert the same name unless the schema enforces uniqueness.
    with get_dict_cursor() as (conn, cur):
        cur.execute("SELECT author_id FROM authors WHERE name = %s", (author_name,))
        row = cur.fetchone()
        if not row:
            cur.execute(
                "INSERT INTO authors (name) VALUES (%s) RETURNING author_id",
                (author_name,)
            )
            row = cur.fetchone()
        return row['author_id']


def validate_field(field: str, value: Any) -> tuple:
    """
    Validate and coerce a raw value for an editable field.

    Values arrive either as strings (CLI / interactive mode) or as
    JSON-decoded values (batch mode). ``None``, ``''``, ``'null'`` and
    ``'NULL'`` mean "clear this field" and are always accepted.

    Args:
        field: Field name; must be a key of EDITABLE_FIELDS.
        value: Raw value to validate.

    Returns:
        (is_valid, error_message, converted_value). On success error_message
        is None; on failure converted_value is None.
    """
    spec = EDITABLE_FIELDS.get(field)
    if spec is None:
        return False, f"Field '{field}' is not editable", None

    # Handle None/null
    if value is None or value in ('', 'null', 'NULL'):
        return True, None, None

    kind = spec['type']

    if kind == 'str':
        text = str(value)
        max_len = spec.get('max_length')
        if max_len and len(text) > max_len:
            return False, f"Value exceeds max length of {max_len}", None
        return True, None, text

    if kind == 'int':
        # bool is a subclass of int: JSON `true` must not silently become 1
        if isinstance(value, bool):
            return False, "Value must be an integer", None
        # Reject e.g. 1920.7 instead of silently truncating it to 1920
        if isinstance(value, float) and not value.is_integer():
            return False, "Value must be an integer", None
        try:
            number = int(value)
        except (TypeError, ValueError):
            # TypeError covers JSON lists/objects from batch files
            return False, "Value must be an integer", None
        min_val = spec.get('min')
        max_val = spec.get('max')
        if min_val is not None and number < min_val:
            return False, f"Value must be >= {min_val}", None
        if max_val is not None and number > max_val:
            return False, f"Value must be <= {max_val}", None
        return True, None, number

    if kind == 'bool':
        if isinstance(value, bool):
            return True, None, value
        token = str(value).lower()
        if token in ('true', '1', 'yes', 'y'):
            return True, None, True
        if token in ('false', '0', 'no', 'n'):
            return True, None, False
        return False, "Value must be true/false", None

    return False, f"Unknown type {kind}", None


def check_searchable_fields_changed(changes: List[tuple]) -> List[str]:
    """
    Pick out the changed fields that are listed in SEARCHABLE_METADATA_FIELDS.

    Args:
        changes: Change records as (field, old_value, new_value) tuples.

    Returns:
        Names of the searchable fields among the changes, in change order
        (duplicates preserved).
    """
    return [
        name
        for name, _before, _after in changes
        if name in SEARCHABLE_METADATA_FIELDS
    ]


def update_document(doc_id: str, updates: Dict[str, Any], dry_run: bool = False) -> Dict:
    """
    Update document metadata in place.

    Every value is validated before anything is written. If any field is
    protected or invalid, nothing is changed. Only fields whose value
    actually differs from the stored one are written.

    The pseudo-field ``author`` is compared by name against the document's
    current author. When it changes, it is resolved to an ``author_id`` via
    the authors table. A new author row is created only when the update is
    really applied, never in dry-run mode. A null/empty author clears the
    document's ``author_id``.

    Args:
        doc_id: Document ID to update.
        updates: Mapping of field name -> raw value (CLI strings or
            JSON-decoded values from a batch file).
        dry_run: If True, compute and report the changes without writing.

    Returns dict with:
        - success: bool
        - changes: list of (field, old_value, new_value)
        - errors: list of error messages
        - searchable_changed: list of searchable fields that were modified
          (also filled in dry-run mode)
        - message: summary string (only when success is True)
    """
    result = {
        'success': False,
        'changes': [],
        'errors': [],
        'searchable_changed': []
    }

    # Get current document
    doc = get_document(doc_id)
    if not doc:
        result['errors'].append(f"Document {doc_id} not found")
        return result

    # Validate all updates first so one bad field aborts the whole edit
    validated_updates = {}
    for field, value in updates.items():
        if field in PROTECTED_FIELDS:
            result['errors'].append(f"Field '{field}' is protected and cannot be edited")
            continue

        is_valid, error, converted = validate_field(field, value)
        if not is_valid:
            result['errors'].append(f"{field}: {error}")
        else:
            validated_updates[field] = converted

    if result['errors']:
        return result

    # 'author' is not a documents column; it is stored as author_id.
    # Compare by name so re-setting the current author is a no-op.
    author_changed = False
    new_author_name = None
    if 'author' in validated_updates:
        new_author_name = validated_updates.pop('author')
        old_author_name = doc.get('author_name')
        if new_author_name != old_author_name:
            author_changed = True
            result['changes'].append(('author', old_author_name, new_author_name))

    # Keep only the columns whose value actually changes
    columns_to_write = {}
    for field, new_value in validated_updates.items():
        old_value = doc.get(field)
        if old_value != new_value:
            columns_to_write[field] = new_value
            result['changes'].append((field, old_value, new_value))

    result['searchable_changed'] = check_searchable_fields_changed(result['changes'])

    if not result['changes']:
        result['success'] = True
        result['message'] = "No changes needed"
        return result

    if dry_run:
        result['success'] = True
        result['message'] = "Dry run - no changes applied"
        return result

    if author_changed:
        # Resolved only now, so dry runs and no-op edits never insert author rows
        columns_to_write['author_id'] = (
            get_or_create_author(new_author_name) if new_author_name else None
        )

    # Column names come only from EDITABLE_FIELDS (validated above) plus
    # author_id, so interpolating them is safe; all values are bound params.
    set_clauses = [f"{field} = %s" for field in columns_to_write]
    # Always update updated_at
    set_clauses.append("updated_at = CURRENT_TIMESTAMP")
    values = list(columns_to_write.values())
    values.append(doc_id)

    query = f"""
        UPDATE documents
        SET {', '.join(set_clauses)}
        WHERE document_id = %s
    """

    with get_dict_cursor() as (conn, cur):
        cur.execute(query, tuple(values))

    result['success'] = True
    result['message'] = f"Updated {len(result['changes'])} field(s)"
    return result


def show_document(doc_id: str):
    """
    Print a grouped, human-readable view of one document's metadata.

    Text fields print '(not set)' when empty; flag fields and the system
    timestamps/status are printed as stored. Prints an error line instead
    when the document cannot be found.
    """
    doc = get_document(doc_id)
    if not doc:
        print(f"Error: Document {doc_id} not found")
        return

    rule = '=' * 60
    print(f"\n{rule}")
    print(f"Document: {doc_id}")
    print(rule)

    def lookup(name):
        # 'author_name' comes from the authors JOIN; fall back to a bare 'author' key
        return doc.get(name) or doc.get(name.replace('_name', ''))

    # (heading, field names, whether to use the author_name fallback lookup)
    text_sections = (
        ("Core Metadata",
         ('title', 'subtitle', 'author_name', 'publication_year', 'publisher', 'language_code'),
         True),
        ("Translation Info",
         ('translator', 'original_title', 'original_language', 'edition_type'),
         False),
        ("Classification",
         ('primary_category', 'content_type', 'difficulty_level'),
         False),
        ("Citation",
         ('bibtex_key', 'bibtex_type'),
         False),
    )
    for heading, names, with_fallback in text_sections:
        print(f"\n--- {heading} ---")
        for name in names:
            shown = lookup(name) if with_fallback else doc.get(name)
            print(f"  {name:20}: {shown or '(not set)'}")

    print("\n--- Flags ---")
    for name in ('needs_review', 'has_index', 'has_bibliography'):
        print(f"  {name:20}: {doc.get(name)}")

    notes = doc.get('notes')
    if notes:
        print("\n--- Notes ---")
        print(f"  {notes}")

    print("\n--- System Info ---")
    print(f"  {'source_file':20}: {doc.get('source_file') or '(not set)'}")
    for name in ('processing_status', 'created_at', 'updated_at'):
        print(f"  {name:20}: {doc.get(name)}")
    print()


def interactive_edit(doc_id: str):
    """
    Edit one document's metadata through a small command loop.

    Commands: ``set <field> <value>``, ``clear <field>``, ``show``,
    ``save``/``s`` and ``quit``/``q``. Each ``set`` value is validated
    immediately and staged locally. Nothing is written until ``save``, which
    applies all staged edits in one update_document() call and then exits.
    ``quit``, EOF and Ctrl-C discard staged edits.

    Args:
        doc_id: Document ID to edit.
    """
    doc = get_document(doc_id)
    if not doc:
        print(f"Error: Document {doc_id} not found")
        return

    print(f"\n{'='*60}")
    print(f"Interactive Edit: {doc_id}")
    print(f"{'='*60}")
    print("\nEditable fields:")
    for i, (field, spec) in enumerate(EDITABLE_FIELDS.items(), 1):
        # 'author' is a pseudo-field; the document row carries the joined author_name
        current = doc.get(field) if field != 'author' else doc.get('author_name')
        # Only None/'' count as unset, so a flag that is False is shown as False
        shown = '(not set)' if current is None or current == '' else current
        print(f"  {i:2}. {field:20} = {shown}")
        print(f"      {spec['description']}")

    print("\nCommands:")
    print("  set <field> <value>  - Set a field value")
    print("  clear <field>        - Clear a field")
    print("  show                 - Show current values")
    print("  save                 - Save changes and exit")
    print("  quit                 - Exit without saving")
    print()

    updates = {}

    while True:
        try:
            cmd = input("edit> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting without saving.")
            return

        if not cmd:
            continue

        # maxsplit=2 keeps spaces inside the value of "set <field> <value>"
        parts = cmd.split(maxsplit=2)
        action = parts[0].lower()

        if action in ('quit', 'q'):
            print("Exiting without saving.")
            return

        elif action in ('save', 's'):
            if not updates:
                print("No changes to save.")
                return

            result = update_document(doc_id, updates)
            if result['success']:
                if result['changes']:
                    print(f"\nSaved {len(result['changes'])} change(s):")
                    for field, old, new in result['changes']:
                        print(f"  {field}: {old} -> {new}")
                else:
                    # Every staged value already matched what is stored
                    print("\nNo changes needed - stored values already match.")
            else:
                print("\nErrors:")
                for error in result['errors']:
                    print(f"  {error}")
            return

        elif action == 'show':
            show_document(doc_id)
            if updates:
                print("Pending changes:")
                for field, value in updates.items():
                    print(f"  {field} = {value}")

        elif action == 'set':
            if len(parts) < 3:
                print("Usage: set <field> <value>")
                continue
            field = parts[1]
            value = parts[2]

            is_valid, error, converted = validate_field(field, value)
            if not is_valid:
                print(f"Error: {error}")
            else:
                updates[field] = converted
                print(f"  {field} = {converted} (pending)")

        elif action == 'clear':
            if len(parts) < 2:
                print("Usage: clear <field>")
                continue
            field = parts[1]
            if field not in EDITABLE_FIELDS:
                print(f"Error: Unknown field '{field}'")
            else:
                updates[field] = None
                print(f"  {field} = (cleared, pending)")

        else:
            print(f"Unknown command: {action}")


def batch_edit(json_file: str, dry_run: bool = False, suppress_embed_warning: bool = False):
    """
    Apply metadata edits for many documents from a JSON file.

    File format:
    {
        "DOC_ID_1": {"title": "New Title", "author": "New Author"},
        "DOC_ID_2": {"publication_year": 1920}
    }

    The file is read as UTF-8; a leading byte-order mark is accepted. Each
    document is updated independently via update_document(). A failing
    document is reported and counted but does not stop the batch. Entries
    whose value is not a JSON object are reported as errors without touching
    the database.

    Args:
        json_file: Path to the JSON edits file.
        dry_run: If True, report changes without writing them.
        suppress_embed_warning: If True, skip the metadata-embedding mismatch
            warning shown after searchable fields were changed.
    """
    try:
        # utf-8-sig also accepts files saved with a BOM (json rejects a raw BOM)
        with open(json_file, 'r', encoding='utf-8-sig') as f:
            edits = json.load(f)
    except FileNotFoundError:
        print(f"Error: File {json_file} not found")
        return
    except json.JSONDecodeError as e:
        print(f"Error: Invalid JSON - {e}")
        return
    except UnicodeDecodeError as e:
        print(f"Error: File {json_file} is not valid UTF-8 - {e}")
        return
    except OSError as e:
        print(f"Error: Could not read {json_file} - {e}")
        return

    if not isinstance(edits, dict):
        print("Error: Batch file must contain a JSON object mapping document IDs to field updates")
        return

    print(f"\n{'='*60}")
    print(f"Batch Edit: {len(edits)} document(s)")
    if dry_run:
        print("(DRY RUN - no changes will be applied)")
    print(f"{'='*60}\n")

    success_count = 0
    error_count = 0
    docs_with_searchable_changes = []  # Track docs that need re-embedding

    for doc_id, updates in edits.items():
        if not isinstance(updates, dict):
            # update_document() needs a field -> value mapping
            error_count += 1
            print(f"[ERROR] {doc_id}:")
            print("      Edits must be a JSON object of field -> value")
            continue

        result = update_document(doc_id, updates, dry_run=dry_run)

        if result['success']:
            success_count += 1
            if result['changes']:
                print(f"[OK] {doc_id}: {len(result['changes'])} change(s)")
                for field, old, new in result['changes']:
                    print(f"      {field}: {old} -> {new}")
                # Track searchable changes
                if result['searchable_changed']:
                    docs_with_searchable_changes.append((doc_id, result['searchable_changed']))
            else:
                print(f"[OK] {doc_id}: No changes needed")
        else:
            error_count += 1
            print(f"[ERROR] {doc_id}:")
            for error in result['errors']:
                print(f"      {error}")

    print(f"\n{'='*60}")
    print(f"Results: {success_count} successful, {error_count} failed")
    if dry_run:
        print("(DRY RUN - no changes were applied)")
    print(f"{'='*60}\n")

    # Show batch warning for searchable field changes
    if docs_with_searchable_changes and not suppress_embed_warning and not dry_run:
        print("\n" + "=" * 65)
        print("METADATA-EMBEDDING MISMATCH WARNING")
        print("=" * 65)
        print(f"\n{len(docs_with_searchable_changes)} document(s) had searchable fields updated:")
        for doc_id, fields in docs_with_searchable_changes:
            print(f"  {doc_id}: {', '.join(fields)}")
        print("\nTo update embeddings, run:")
        if len(docs_with_searchable_changes) <= 5:
            for doc_id, _ in docs_with_searchable_changes:
                print(f"  python embed_chunks.py --document {doc_id} --force")
        else:
            print("  # For each document, or re-embed all:")
            print("  python embed_chunks.py --all --force")
        print("\nTo suppress this warning, use: --no-embed-warning")
        print("=" * 65 + "\n")


def list_review_documents():
    """
    Print every document whose ``needs_review`` flag is true.

    Documents are ordered by ``updated_at`` descending. For each one, the ID,
    title and author are printed, plus a notes preview of at most 100
    characters. The preview gets "..." only when the notes were actually cut.
    """
    query = """
        SELECT d.document_id, d.title, a.name as author, d.needs_review, d.notes
        FROM documents d
        LEFT JOIN authors a ON d.author_id = a.author_id
        WHERE d.needs_review = true
        ORDER BY d.updated_at DESC
    """
    docs = execute_query(query)

    if not docs:
        print("\nNo documents flagged for review.")
        return

    print(f"\n{'='*60}")
    print(f"Documents Needing Review: {len(docs)}")
    print(f"{'='*60}\n")

    preview_len = 100
    for doc in docs:
        print(f"ID: {doc['document_id']}")
        print(f"  Title:  {doc['title'] or '(not set)'}")
        print(f"  Author: {doc['author'] or '(not set)'}")
        notes = doc['notes']
        if notes:
            preview = notes if len(notes) <= preview_len else notes[:preview_len] + '...'
            print(f"  Notes:  {preview}")
        print()


def list_editable_fields():
    """
    Print each editable field with its type constraints and description,
    followed by the list of protected (read-only) fields.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print("Editable Metadata Fields")
    print(f"{banner}\n")

    for name, spec in EDITABLE_FIELDS.items():
        label = spec['type']
        max_length = spec.get('max_length')
        if max_length:
            label = f"{label} (max {max_length})"
        has_range = spec.get('min') is not None or spec.get('max') is not None
        if has_range:
            label = f"{label} ({spec.get('min', '')}-{spec.get('max', '')})"

        print(f"  {name:20} [{label}]")
        print(f"      {spec['description']}")

    print("\n--- Protected Fields (cannot be edited) ---")
    for name in PROTECTED_FIELDS:
        print(f"  {name}")
    print()


def main():
    """
    Parse command-line arguments and dispatch to the requested operation.

    Listing and bulk options (--list-review, --list-fields, --generate-keys,
    --batch) are handled first and need no document ID. Every other mode
    works on the positional DOC_ID and defaults to showing the document.
    Exits with status 1 when no document ID is given or a --set edit fails.
    """
    parser = argparse.ArgumentParser(
        description='Edit document metadata without re-ingesting',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Show document metadata
  python edit_metadata.py DOC123 --show

  # Edit single field
  python edit_metadata.py DOC123 --set title "Corrected Title"

  # Edit multiple fields
  python edit_metadata.py DOC123 --set title "New Title" --set author "Author Name"

  # Interactive mode
  python edit_metadata.py DOC123 --interactive

  # Batch edit from JSON
  python edit_metadata.py --batch edits.json

  # Dry run (preview changes)
  python edit_metadata.py DOC123 --set title "Test" --dry-run

  # List documents needing review
  python edit_metadata.py --list-review

  # List editable fields
  python edit_metadata.py --list-fields

  # Generate missing citation keys (add --overwrite-keys to regenerate all)
  python edit_metadata.py --generate-keys
"""
    )

    parser.add_argument('document_id', nargs='?', help='Document ID to edit')
    parser.add_argument('--show', '-s', action='store_true', help='Show current metadata')
    parser.add_argument('--set', nargs=2, action='append', metavar=('FIELD', 'VALUE'),
                       help='Set a field value (can be used multiple times)')
    parser.add_argument('--interactive', '-i', action='store_true', help='Interactive edit mode')
    parser.add_argument('--batch', metavar='FILE', help='Batch edit from JSON file')
    parser.add_argument('--dry-run', '-n', action='store_true', help='Preview changes without applying')
    parser.add_argument('--list-review', action='store_true', help='List documents needing review')
    parser.add_argument('--list-fields', action='store_true', help='List all editable fields')
    parser.add_argument('--generate-keys', action='store_true',
                       help='Generate citation keys (bibtex_key) for all documents')
    parser.add_argument('--overwrite-keys', action='store_true',
                       help='Regenerate citation keys even if they exist (use with --generate-keys)')
    parser.add_argument('--no-embed-warning', action='store_true',
                       help='Suppress warning about metadata-embedding mismatch')

    args = parser.parse_args()

    # --overwrite-keys only modifies --generate-keys; reject it on its own
    # instead of silently ignoring it
    if args.overwrite_keys and not args.generate_keys:
        parser.error("--overwrite-keys requires --generate-keys")

    # Handle listing options (no document ID needed)
    if args.list_review:
        list_review_documents()
        return

    if args.list_fields:
        list_editable_fields()
        return

    if args.generate_keys:
        print("\nGenerating citation keys...")
        stats = generate_citation_keys_batch(overwrite=args.overwrite_keys)
        print(f"Generated: {stats['generated']}")
        if stats.get('skipped'):
            print(f"Skipped (unchanged): {stats['skipped']}")
        print(f"Errors: {stats['errors']}")
        if args.overwrite_keys:
            print("(Overwrite mode - all keys regenerated)")
        return

    if args.batch:
        batch_edit(args.batch, dry_run=args.dry_run, suppress_embed_warning=args.no_embed_warning)
        return

    # All other operations require document ID
    if not args.document_id:
        parser.print_help()
        print("\nError: Document ID required (or use --list-review, --list-fields, --generate-keys, --batch)")
        sys.exit(1)

    doc_id = args.document_id

    if args.show:
        show_document(doc_id)

    elif args.interactive:
        interactive_edit(doc_id)

    elif args.set:
        updates = {field: value for field, value in args.set}
        result = update_document(doc_id, updates, dry_run=args.dry_run)

        if result['success']:
            if result['changes']:
                prefix = "[DRY RUN] " if args.dry_run else ""
                print(f"\n{prefix}Updated {doc_id}:")
                for field, old, new in result['changes']:
                    print(f"  {field}: {old} -> {new}")

                # Show warning for searchable field changes (unless suppressed)
                if result['searchable_changed'] and not args.no_embed_warning and not args.dry_run:
                    fields_str = ', '.join(result['searchable_changed'])
                    print(EMBEDDING_MISMATCH_WARNING.format(
                        fields=fields_str,
                        doc_id=doc_id
                    ))
            else:
                print(f"\nNo changes needed for {doc_id}")
        else:
            print(f"\nErrors for {doc_id}:")
            for error in result['errors']:
                print(f"  {error}")
            sys.exit(1)

    else:
        # Default to showing document
        show_document(doc_id)


if __name__ == '__main__':
    main()
