#!/usr/bin/env python3
"""
Citation Generation Tool

Generate properly formatted citations from document references.
Designed as a tool for Claude Code to use when finalizing manuscripts.

Usage:
    # Single citation
    python cite.py DOC_023 --page 45 --style chicago

    # Batch mode - process a chapter file
    python cite.py --chapter chapter_03.md --style chicago

    # Generate bibliography
    python cite.py --bibliography BOOK_xxx --style chicago

Output:
    Formatted citations in the requested style (Chicago, MLA, APA, or simple)
"""

import argparse
import json
import logging
import re
import sys
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass

# Setup path for imports
sys.path.insert(0, str(Path(__file__).parent))

from config import LOGGING_CONFIG
from db_utils import execute_query

# Setup logging. The configured level may be a level name in any case
# ('INFO', 'debug', ...) or a numeric level; anything unrecognised falls back
# to INFO instead of crashing at import time.
LOG_LEVEL = LOGGING_CONFIG.get('level', 'INFO')
if isinstance(LOG_LEVEL, int):
    _log_level = LOG_LEVEL
else:
    # Upper-case first: getattr(logging, 'info') would return the logging.info
    # *function*, which basicConfig rejects with a TypeError.
    _log_level = getattr(logging, str(LOG_LEVEL).upper(), logging.INFO)
    if not isinstance(_log_level, int):
        _log_level = logging.INFO
logging.basicConfig(
    level=_log_level,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class DocumentMetadata:
    """Bibliographic metadata used to generate citations.

    ``__post_init__`` records which fields are missing. An author of
    ``'Unknown'`` or a title of ``'Untitled'`` (the placeholders substituted by
    the database lookups for NULL values) count as missing.
    """
    document_id: str
    title: str
    author: str
    publication_year: Optional[int]
    publisher: Optional[str] = None
    place_of_publication: Optional[str] = None
    editor: Optional[str] = None
    translator: Optional[str] = None
    edition: Optional[str] = None
    volume: Optional[str] = None
    url: Optional[str] = None
    accessed_date: Optional[str] = None

    # Quality flags, filled in by __post_init__.
    # metadata_complete becomes False when an essential field (author, title)
    # is missing; a missing publication_year is listed in missing_fields but
    # does not mark the record incomplete.
    metadata_complete: bool = True
    missing_fields: Optional[List[str]] = None

    def __post_init__(self):
        # Work on a copy so a caller-supplied list is never mutated, and skip
        # names already present so entries are never duplicated.
        missing = list(self.missing_fields or [])
        checks = (
            # (field name, is missing?, essential?)
            ('author', not self.author or self.author == 'Unknown', True),
            ('title', not self.title or self.title == 'Untitled', True),
            ('publication_year', not self.publication_year, False),
        )
        for name, is_missing, essential in checks:
            if not is_missing:
                continue
            if name not in missing:
                missing.append(name)
            if essential:
                self.metadata_complete = False
        self.missing_fields = missing


@dataclass
class Citation:
    """One document's citation rendered in a single style.

    For MLA and APA, ``footnote`` and ``short_form`` carry the in-text form.
    """
    document_id: str
    style: str
    footnote: str  # Full note (Chicago) or in-text citation (MLA/APA)
    bibliography: str  # Bibliography / Works Cited / Reference entry
    short_form: str  # For subsequent references
    page: Optional[str] = None
    warnings: Optional[List[str]] = None  # Metadata problems; [] when none

    def __post_init__(self):
        # Each instance gets its own list instead of a shared default.
        self.warnings = self.warnings if self.warnings is not None else []


# =============================================================================
# DATABASE LOOKUP
# =============================================================================

def get_document_metadata(document_id: str) -> Optional[DocumentMetadata]:
    """Look up one document (joined with its author) and build its metadata.

    NULL titles/authors are replaced by the placeholders 'Untitled' and
    'Unknown'.

    Returns:
        The ``DocumentMetadata``, or ``None`` when no row matches or when the
        query or conversion raises (the error is logged).
    """
    sql = """
        SELECT
            d.document_id,
            d.title,
            a.name as author,
            d.publication_year,
            d.publisher,
            d.translator,
            d.edition_type,
            d.primary_category
        FROM documents d
        LEFT JOIN authors a ON d.author_id = a.author_id
        WHERE d.document_id = %s
    """

    try:
        row = execute_query(sql, (document_id,), fetch='one')
        metadata = None
        if row:
            metadata = DocumentMetadata(
                document_id=row['document_id'],
                title=row['title'] or 'Untitled',
                author=row['author'] or 'Unknown',
                publication_year=row['publication_year'],
                publisher=row.get('publisher'),
                translator=row.get('translator'),
                edition=row.get('edition_type'),
            )
    except Exception as err:
        logger.error(f"Database error: {err}")
        metadata = None
    return metadata


def get_multiple_documents(document_ids: List[str]) -> Dict[str, DocumentMetadata]:
    """Fetch metadata for several documents with a single query.

    Args:
        document_ids: Document IDs to look up. Duplicates are ignored.

    Returns:
        Mapping of ``document_id`` (as returned by the database) to its
        ``DocumentMetadata``; IDs with no matching row are simply absent.
        On a database error the error is logged and ``{}`` is returned.
    """
    # De-duplicate (order preserved) so the IN (...) list has exactly one
    # placeholder per distinct ID.
    unique_ids = list(dict.fromkeys(document_ids or []))
    if not unique_ids:
        return {}

    # Only '%s' markers are interpolated; the IDs themselves are passed as
    # query parameters.
    placeholders = ', '.join(['%s'] * len(unique_ids))
    query = f"""
        SELECT
            d.document_id,
            d.title,
            a.name as author,
            d.publication_year,
            d.publisher,
            d.translator,
            d.edition_type
        FROM documents d
        LEFT JOIN authors a ON d.author_id = a.author_id
        WHERE d.document_id IN ({placeholders})
    """

    try:
        results = execute_query(query, tuple(unique_ids), fetch='all')
        metadata = {}
        # Treat a None result as "no rows" rather than letting the loop raise
        # a TypeError that would be logged as a database error.
        for r in results or []:
            metadata[r['document_id']] = DocumentMetadata(
                document_id=r['document_id'],
                title=r['title'] or 'Untitled',
                author=r['author'] or 'Unknown',
                publication_year=r['publication_year'],
                publisher=r.get('publisher'),
                translator=r.get('translator'),
                edition=r.get('edition_type'),
            )
        return metadata
    except Exception as e:
        logger.error(f"Database error: {e}")
        return {}


# =============================================================================
# CITATION FORMATTERS
# =============================================================================

def format_author_chicago(author: str, inverted: bool = True) -> str:
    """Format an author string for Chicago style.

    Multiple authors are expected to be separated by ``' and '``. With
    ``inverted=True`` (bibliography order) only the first author is inverted:
    "Austen, Jane, and Mark Twain" or, for three or more,
    "Austen, Jane, Mark Twain, and Oscar Wilde". A name that already contains
    a comma is assumed to be inverted and is left alone.

    Args:
        author: Author name(s); empty or ``'Unknown'`` yields ``'Unknown'``.
        inverted: Invert the first author to "Last, First".

    Returns:
        The formatted author string.
    """
    if not author or author == 'Unknown':
        return 'Unknown'

    # Handle multiple authors
    if ' and ' in author:
        if not inverted:
            return author
        names = [name.strip() for name in author.split(' and ') if name.strip()]
        if not names:
            return author
        # First author inverted, the rest in normal order.
        first = format_author_chicago(names[0], inverted=True)
        rest = names[1:]
        if not rest:
            return first
        # Series comma before the final "and" (Chicago style).
        return ', '.join([first] + rest[:-1]) + f", and {rest[-1]}"

    # Single author - try to invert (Last, First)
    if inverted and ',' not in author:
        parts = author.strip().split()
        if len(parts) >= 2:
            return f"{parts[-1]}, {' '.join(parts[:-1])}"

    return author


def format_chicago_footnote(meta: DocumentMetadata, page: Optional[str] = None) -> str:
    """Format a Chicago-style (notes-bibliography) footnote citation.

    Layout: ``Author, *Title* (Place: Publisher, Year), page.`` The facts of
    publication attach directly to the title (no comma before the
    parenthesis). Each element is omitted when its metadata is absent.

    Args:
        meta: Document metadata; an author of ``'Unknown'`` is omitted.
        page: Optional page or page range.

    Returns:
        The footnote text, terminated with a period.
    """
    parts = []

    # Author (normal order for footnotes)
    if meta.author and meta.author != 'Unknown':
        parts.append(meta.author)

    # Facts of publication: "Place: Publisher, Year" (same colon convention
    # as the bibliography entry).
    if meta.place_of_publication and meta.publisher:
        imprint = f"{meta.place_of_publication}: {meta.publisher}"
    else:
        imprint = meta.publisher or meta.place_of_publication
    pub_info = [str(item) for item in (imprint, meta.publication_year) if item]

    # Title (italicized for books) followed directly by the parenthetical.
    title_segment = f"*{meta.title}*" if meta.title else ''
    if pub_info:
        facts = f"({', '.join(pub_info)})"
        title_segment = f"{title_segment} {facts}" if title_segment else facts
    if title_segment:
        parts.append(title_segment)

    # Page number
    if page:
        parts.append(page)

    return ', '.join(parts) + '.'


def format_chicago_bibliography(meta: DocumentMetadata) -> str:
    """Format a Chicago-style bibliography entry.

    Layout: ``Last, First. *Title*. Place: Publisher, Year.`` Every element
    ends with a period; place is only printed together with a publisher.

    Args:
        meta: Document metadata.

    Returns:
        The bibliography entry.
    """
    # Author (inverted for bibliography). format_author_chicago returns
    # 'Unknown' when no author is recorded. Avoid doubling the period after
    # trailing initials such as "Tolkien, J. R. R."
    author = format_author_chicago(meta.author, inverted=True)
    parts = [author if author.endswith('.') else f"{author}."]

    # Title
    if meta.title:
        parts.append(f"*{meta.title}*.")

    # Publication info: "Place: Publisher, Year." with no dangling comma when
    # either the publisher or the year is unknown.
    if meta.place_of_publication and meta.publisher:
        publisher = f"{meta.place_of_publication}: {meta.publisher}"
    else:
        publisher = meta.publisher
    imprint = [str(item) for item in (publisher, meta.publication_year) if item]
    if imprint:
        parts.append(', '.join(imprint) + '.')

    return ' '.join(parts)


def format_chicago_short(meta: DocumentMetadata, page: Optional[str] = None) -> str:
    """Format a Chicago short-form citation (for subsequent references).

    Layout: ``Surname(s), *Short Title*, page.`` The short title is the first
    three words of the title, with ``...`` appended when truncated.

    Authors are split on ``' and '``. A name stored as "Last, First" uses the
    part before the comma; otherwise its final word. One or two authors are
    joined with "and", three become "A, B, and C", four or more "A et al.".

    Args:
        meta: Document metadata; an author of ``'Unknown'`` is omitted.
        page: Optional page or page range.

    Returns:
        The short-form note, terminated with a single period.
    """
    parts = []

    # Author surname(s) only
    if meta.author and meta.author != 'Unknown':
        surnames = []
        for name in meta.author.split(' and '):
            name = name.strip()
            if not name:
                continue
            surnames.append(
                name.split(',', 1)[0].strip() if ',' in name else name.split()[-1]
            )
        if len(surnames) >= 4:
            parts.append(f"{surnames[0]} et al.")
        elif len(surnames) == 3:
            parts.append(f"{surnames[0]}, {surnames[1]}, and {surnames[2]}")
        elif surnames:
            parts.append(' and '.join(surnames))

    # Short title (first few words)
    if meta.title:
        title_words = meta.title.split()
        short_title = ' '.join(title_words[:3])
        if len(title_words) > 3:
            short_title += '...'
        parts.append(f"*{short_title}*")

    if page:
        parts.append(page)

    # Don't double the final period (e.g. after "et al." with nothing following).
    joined = ', '.join(parts)
    return joined if joined.endswith('.') else f"{joined}."


def format_mla(meta: DocumentMetadata, page: Optional[str] = None) -> Tuple[str, str]:
    """Format an MLA-style citation.

    Authors are split on ``' and '``. Works Cited lists one or two authors in
    full ("Austen, Jane, and Mark Twain.") and three or more as
    "Austen, Jane, et al.". The in-text citation uses the surname, "A and B",
    or "A et al.". A name stored as "Last, First" takes its surname from
    before the comma.

    Args:
        meta: Document metadata.
        page: Optional page or page range for the in-text citation.

    Returns:
        ``(in_text, works_cited)``, e.g.
        ``("(Austen 45)", "Austen, Jane. *Emma*. Penguin, 2003.")``.
    """
    names = []
    if meta.author and meta.author != 'Unknown':
        names = [n.strip() for n in meta.author.split(' and ') if n.strip()]

    # Works Cited entry
    # Author (inverted); three or more authors collapse to "First, et al."
    if len(names) >= 3:
        author = f"{format_author_chicago(names[0], inverted=True)}, et al."
    else:
        author = format_author_chicago(meta.author, inverted=True)
    # The author element ends with one period, even after trailing initials.
    parts = [author if author.endswith('.') else f"{author}."]

    # Title (italicized)
    if meta.title:
        parts.append(f"*{meta.title}*.")

    # "Publisher, Year." with no dangling comma when either is unknown.
    imprint = [str(item) for item in (meta.publisher, meta.publication_year) if item]
    if imprint:
        parts.append(', '.join(imprint) + '.')

    works_cited = ' '.join(parts)

    # In-text citation (Author Page)
    surnames = [n.split(',', 1)[0].strip() if ',' in n else n.split()[-1] for n in names]
    if not surnames:
        author_label = 'Unknown'
    elif len(surnames) == 1:
        author_label = surnames[0]
    elif len(surnames) == 2:
        author_label = f"{surnames[0]} and {surnames[1]}"
    else:
        author_label = f"{surnames[0]} et al."

    in_text = f"({author_label} {page})" if page else f"({author_label})"
    return in_text, works_cited


def format_apa(meta: DocumentMetadata, page: Optional[str] = None) -> Tuple[str, str]:
    """Format an APA-style citation.

    Authors are split on ``' and '``; each may be "First Middle Last" or
    "Last, First". The reference lists every author as "Last, F. M.", with
    two or more joined as "A, B, & C". The in-text citation uses the surname,
    "A & B", or "A et al.". A page range (hyphen or en dash) gets "pp.",
    a single page "p.".

    Args:
        meta: Document metadata.
        page: Optional page or page range for the in-text citation.

    Returns:
        ``(in_text, reference)``, e.g.
        ``("(Austen, 2003, p. 45)", "Austen, J. (2003). *Emma*. Penguin.")``.
    """
    def split_name(name: str) -> Tuple[str, List[str]]:
        # (surname, given names) from "Last, First ..." or "First ... Last".
        if ',' in name:
            surname, _, given = name.partition(',')
            return surname.strip(), given.split()
        words = name.split()
        return words[-1], words[:-1]

    authors = []
    if meta.author and meta.author != 'Unknown':
        authors = [split_name(n.strip()) for n in meta.author.split(' and ') if n.strip()]

    # Reference entry
    parts = []

    # Author(s): "Last, F. M."; the element ends with a period.
    if authors:
        formatted = [
            f"{surname}, {' '.join(g[0] + '.' for g in given)}" if given else surname
            for surname, given in authors
        ]
        if len(formatted) == 1:
            author_text = formatted[0]
        else:
            author_text = ', '.join(formatted[:-1]) + f", & {formatted[-1]}"
        parts.append(author_text if author_text.endswith('.') else f"{author_text}.")

    # Year
    if meta.publication_year:
        parts.append(f"({meta.publication_year}).")

    # Title, italicized. Stored capitalization is kept as-is: APA expects
    # sentence case, which cannot be derived safely (proper nouns).
    if meta.title:
        parts.append(f"*{meta.title}*.")

    # Publisher
    if meta.publisher:
        parts.append(f"{meta.publisher}.")

    reference = ' '.join(parts)

    # In-text citation (Author, Year, p. X)
    surnames = [surname for surname, _ in authors]
    if not surnames:
        author_label = 'Unknown'
    elif len(surnames) == 1:
        author_label = surnames[0]
    elif len(surnames) == 2:
        author_label = f"{surnames[0]} & {surnames[1]}"
    else:
        author_label = f"{surnames[0]} et al."

    year = meta.publication_year or 'n.d.'
    if page:
        prefix = 'pp.' if ('-' in page or '\u2013' in page) else 'p.'
        in_text = f"({author_label}, {year}, {prefix} {page})"
    else:
        in_text = f"({author_label}, {year})"

    return in_text, reference


def format_simple(meta: DocumentMetadata, page: Optional[str] = None) -> str:
    """Build a quick, unstyled citation for drafts.

    Layout: ``Author, 'Title', (Year), p.Page``. The author falls back to
    'Unknown'; every later piece is dropped when its value is empty.
    """
    pieces = [
        meta.author or 'Unknown',
        f"'{meta.title}'" if meta.title else None,
        f"({meta.publication_year})" if meta.publication_year else None,
        f"p.{page}" if page else None,
    ]
    return ', '.join(piece for piece in pieces if piece is not None)


# =============================================================================
# MAIN CITATION GENERATOR
# =============================================================================

def generate_citation(
    document_id: str,
    style: str = 'chicago',
    page: Optional[str] = None,
    metadata: Optional[DocumentMetadata] = None,
) -> Optional[Citation]:
    """Generate a citation for a document.

    Args:
        document_id: ID of the document being cited (``DOC_xxx``).
        style: ``'chicago'``, ``'mla'`` or ``'apa'``; any other value is
            rendered with the ``'simple'`` draft format.
        page: Optional page or page range to cite.
        metadata: Metadata already fetched for ``document_id`` (e.g. from a
            batched ``get_multiple_documents`` call). When ``None`` it is
            looked up with ``get_document_metadata``.

    Returns:
        The ``Citation``, or ``None`` when the document cannot be found.
        For MLA and APA, ``footnote`` and ``short_form`` hold the in-text form.
    """
    # Fetch metadata unless the caller already has it.
    meta = metadata if metadata is not None else get_document_metadata(document_id)
    if not meta:
        logger.warning(f"Document not found: {document_id}")
        return None

    warnings = []
    if not meta.metadata_complete:
        warnings.append(f"Missing fields: {', '.join(meta.missing_fields)}")

    # Generate based on style
    if style == 'chicago':
        footnote = format_chicago_footnote(meta, page)
        bibliography = format_chicago_bibliography(meta)
        short_form = format_chicago_short(meta, page)
    elif style == 'mla':
        footnote, bibliography = format_mla(meta, page)
        short_form = footnote
    elif style == 'apa':
        footnote, bibliography = format_apa(meta, page)
        short_form = footnote
    else:  # simple
        footnote = format_simple(meta, page)
        bibliography = footnote
        short_form = footnote

    return Citation(
        document_id=document_id,
        style=style,
        footnote=footnote,
        bibliography=bibliography,
        short_form=short_form,
        page=page,
        warnings=warnings,
    )


# =============================================================================
# BATCH PROCESSING
# =============================================================================

# A DOC_xxx identifier, optionally followed by a page reference such as
# ", p. 45", " p45" or " page 12-14". Matched case-insensitively.
_DOC_REFERENCE_PATTERN = re.compile(
    r'\b(DOC_[a-zA-Z0-9_]+)(?:\s*[,\s]*(?:p\.?\s*|page\s*)(\d+(?:-\d+)?))?',
    re.IGNORECASE,
)


def extract_document_references(text: str) -> List[Tuple[str, Optional[str]]]:
    """Extract document references (``DOC_xxx``) from text.

    IDs are normalized to upper case. Each document is reported once, in
    order of first appearance, with the page from that first occurrence;
    pages cited at later occurrences of the same document are not reported.

    Returns:
        List of ``(document_id, page)`` tuples; ``page`` is ``None`` when the
        first occurrence carries no page reference.
    """
    seen = set()
    results = []
    for doc_id, page in _DOC_REFERENCE_PATTERN.findall(text):
        doc_id = doc_id.upper()
        if doc_id not in seen:
            seen.add(doc_id)
            results.append((doc_id, page or None))
    return results


def process_chapter(
    chapter_path: Path,
    style: str = 'chicago',
) -> Dict[str, Any]:
    """Process a chapter file and generate citations for all references.

    Args:
        chapter_path: Chapter file, read as UTF-8.
        style: Citation style passed to ``generate_citation``.

    Returns:
        Dict with ``chapter`` (file name), ``citations`` (one per referenced
        document, in order of first appearance), ``bibliography`` (sorted
        case-insensitively by entry text) and ``warnings``.
    """
    text = chapter_path.read_text(encoding='utf-8')

    # Extract references
    references = extract_document_references(text)

    if not references:
        return {
            'chapter': chapter_path.name,
            'citations': [],
            'bibliography': [],
            'warnings': ['No document references (DOC_xxx) found in chapter'],
        }

    citations = []
    bibliography = []
    warnings = []

    # One batched query up front; generate_citation only falls back to a
    # per-document lookup for IDs the batch did not return.
    all_metadata = get_multiple_documents([doc_id for doc_id, _ in references])

    for doc_id, page in references:
        citation = generate_citation(doc_id, style, page, metadata=all_metadata.get(doc_id))
        if citation is None:
            warnings.append(f"{doc_id}: Document not found in database")
            continue
        citations.append({
            'document_id': doc_id,
            'page': page,
            'footnote': citation.footnote,
            'short_form': citation.short_form,
        })
        bibliography.append({
            'document_id': doc_id,
            'entry': citation.bibliography,
        })
        warnings.extend(f"{doc_id}: {w}" for w in citation.warnings)

    # Sort bibliography alphabetically, ignoring case.
    bibliography.sort(key=lambda item: item['entry'].casefold())

    return {
        'chapter': chapter_path.name,
        'citations': citations,
        'bibliography': bibliography,
        'warnings': warnings,
    }


def generate_bibliography(
    project_id: str,
    style: str = 'chicago',
) -> Dict[str, Any]:
    """Generate a bibliography for an entire book project.

    Reads ``<projects_dir>/<project_id>/sources.json`` and cites every
    ``library_sources[*].document_id`` listed there.

    Returns:
        Dict with ``project_id``, ``style``, ``total_sources``, ``entries``
        (sorted case-insensitively) and ``warnings``. When the sources file is
        missing, is not valid JSON, or lists no library sources, ``entries``
        is empty and the reason is given in ``warnings``.
    """
    from config import BOOK_WORKFLOW_CONFIG

    def empty_result(message: str) -> Dict[str, Any]:
        # Include style/total_sources so text output reports the requested style.
        return {
            'project_id': project_id,
            'style': style,
            'total_sources': 0,
            'entries': [],
            'warnings': [message],
        }

    project_dir = BOOK_WORKFLOW_CONFIG['projects_dir'] / project_id
    sources_path = project_dir / 'sources.json'

    if not sources_path.exists():
        return empty_result(f'sources.json not found in project {project_id}')

    try:
        sources_data = json.loads(sources_path.read_text(encoding='utf-8'))
    except json.JSONDecodeError as exc:
        return empty_result(f'sources.json in project {project_id} is not valid JSON: {exc}')

    # Get unique document IDs from library sources
    library_sources = (
        sources_data.get('library_sources', []) if isinstance(sources_data, dict) else []
    )
    doc_ids = {src['document_id'] for src in library_sources if src.get('document_id')}

    if not doc_ids:
        return empty_result('No library sources found in project')

    entries = []
    warnings = []

    # One batched query; generate_citation falls back per document if needed.
    all_metadata = get_multiple_documents(list(doc_ids))

    for doc_id in sorted(doc_ids):
        citation = generate_citation(doc_id, style, metadata=all_metadata.get(doc_id))
        if citation is None:
            warnings.append(f"{doc_id}: Document not found")
            continue
        entries.append({
            'document_id': doc_id,
            'entry': citation.bibliography,
        })
        warnings.extend(f"{doc_id}: {w}" for w in citation.warnings)

    # Sort alphabetically, ignoring case.
    entries.sort(key=lambda item: item['entry'].casefold())

    return {
        'project_id': project_id,
        'style': style,
        'total_sources': len(entries),
        'entries': entries,
        'warnings': warnings,
    }


# =============================================================================
# OUTPUT FORMATTING
# =============================================================================

def format_output_text(result: Dict[str, Any], mode: str = 'single') -> str:
    """Render a result dict as Markdown-style text for display.

    Args:
        result: Output of the single-citation, chapter or bibliography mode.
        mode: ``'single'``, ``'chapter'`` or ``'bibliography'``; any other
            value renders nothing.

    Returns:
        The rendered lines joined with newlines.
    """
    out: List[str] = []

    def add_warnings(heading: str) -> None:
        # Shared trailer: a blank line, the heading, then one bullet per warning.
        if result.get('warnings'):
            out.extend(["", heading])
            out.extend(f"- {w}" for w in result['warnings'])

    if mode == 'single':
        out.extend([
            f"## Citation: {result.get('document_id', 'Unknown')}",
            "",
            "### Footnote",
            f"{result.get('footnote', '')}",
            "",
            "### Bibliography Entry",
            f"{result.get('bibliography', '')}",
            "",
            "### Short Form (subsequent references)",
            f"{result.get('short_form', '')}",
        ])
        add_warnings("### Warnings")
    elif mode == 'chapter':
        out.extend([f"## Citations: {result.get('chapter', 'Unknown')}", ""])
        if result.get('citations'):
            out.extend(["### Footnotes", ""])
            out.extend(
                f"{number}. {cite['footnote']}"
                for number, cite in enumerate(result['citations'], 1)
            )
            out.extend(["", "### Bibliography", ""])
            out.extend(f"- {item['entry']}" for item in result['bibliography'])
        add_warnings("### Warnings")
    elif mode == 'bibliography':
        out.extend([
            "## Bibliography",
            f"**Project:** {result.get('project_id', 'Unknown')}",
            f"**Style:** {result.get('style', 'chicago')}",
            f"**Total Sources:** {result.get('total_sources', 0)}",
            "",
        ])
        out.extend(f"- {item['entry']}" for item in result.get('entries', []))
        add_warnings("### Metadata Warnings")

    return '\n'.join(out)


# =============================================================================
# CLI
# =============================================================================

def create_parser() -> argparse.ArgumentParser:
    """Build the ``cite`` command-line parser.

    Supports three modes: a positional document ID, ``--chapter`` and
    ``--bibliography``, plus shared ``--page``, ``--style``, ``--format``
    and ``--output`` options.
    """
    examples = """
Examples:
  # Single citation
  python cite.py DOC_023 --page 45 --style chicago

  # Process chapter file
  python cite.py --chapter chapter_03.md --style chicago

  # Generate project bibliography
  python cite.py --bibliography BOOK_xxx --style mla

Citation styles: chicago (default), mla, apa, simple
        """
    parser = argparse.ArgumentParser(
        prog='cite',
        description='Generate properly formatted citations from document references',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=examples,
    )

    # (flags, options) in the order they should appear in --help.
    argument_specs = [
        (('document_id',), dict(
            nargs='?', type=str,
            help='Document ID (DOC_xxx) to generate citation for')),
        (('--page', '-p'), dict(
            type=str,
            help='Page number(s) for citation')),
        (('--style', '-s'), dict(
            choices=['chicago', 'mla', 'apa', 'simple'], default='chicago',
            help='Citation style (default: chicago)')),
        (('--chapter', '-c'), dict(
            type=str,
            help='Process a chapter file and generate all citations')),
        (('--bibliography', '-b'), dict(
            type=str, metavar='PROJECT_ID',
            help='Generate bibliography for a book project')),
        (('--format', '-f'), dict(
            choices=['text', 'json'], default='text',
            help='Output format (default: text)')),
        (('--output', '-o'), dict(
            type=str,
            help='Output file path')),
    ]
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    return parser


def main() -> int:
    """Command-line entry point.

    Exactly one mode runs, checked in this order: ``--bibliography``,
    ``--chapter``, then a positional document ID. The result is printed (or
    written to ``--output``) as text or JSON. Error messages go to stderr so
    they never mix with the citation output on stdout.

    Returns:
        Exit status: 1 on error, when no mode was selected, or when the
        result carries any warnings; 0 otherwise.
    """
    parser = create_parser()
    args = parser.parse_args()

    # Determine operation mode
    if args.bibliography:
        # Generate project bibliography
        result = generate_bibliography(args.bibliography, args.style)
        mode = 'bibliography'

    elif args.chapter:
        # Process chapter file. is_file() also rejects directories, which
        # read_text() would otherwise fail on with a traceback.
        chapter_path = Path(args.chapter)
        if not chapter_path.is_file():
            print(f"Error: Chapter file not found: {chapter_path}", file=sys.stderr)
            return 1
        result = process_chapter(chapter_path, args.style)
        mode = 'chapter'

    elif args.document_id:
        # Single citation
        citation = generate_citation(args.document_id, args.style, args.page)
        if not citation:
            print(f"Error: Document not found: {args.document_id}", file=sys.stderr)
            return 1
        result = {
            'document_id': citation.document_id,
            'style': citation.style,
            'footnote': citation.footnote,
            'bibliography': citation.bibliography,
            'short_form': citation.short_form,
            'page': citation.page,
            'warnings': citation.warnings,
        }
        mode = 'single'

    else:
        parser.print_help()
        return 1

    # Format output
    if args.format == 'json':
        output = json.dumps(result, indent=2)
    else:
        output = format_output_text(result, mode)

    # Write or print
    if args.output:
        Path(args.output).write_text(output, encoding='utf-8')
        print(f"Output written to: {args.output}")
    else:
        print(output)

    # Warnings (incomplete metadata, unresolved references) yield a non-zero
    # status so scripted callers notice them.
    return 1 if result.get('warnings') else 0


if __name__ == '__main__':
    sys.exit(main())
