#!/usr/bin/env python3
"""
Book Workflow Data Models

Data structures for a CLI-first research workflow for book projects.
Supports multi-chapter books with subject-based research, gap tracking,
and checkpoint/resume capabilities.

Models:
- BookProject: Master project containing all chapters
- ChapterProject: Individual chapter with subjects and research
- WorkflowCheckpoint: Resume point for interrupted workflows
- SubjectResearch: Research results for a single subject
- IdentifiedGap: Research gap record (imported from research_agent when
  available, otherwise a minimal local fallback)
"""

import os
import json
import logging
import hashlib
import secrets
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field, asdict

# Module-level logger; this module adds no handlers or level configuration.
logger: logging.Logger = logging.getLogger(__name__)

# Import existing types
try:
    from research_agent import IdentifiedGap, ResearchSession
except ImportError:
    # Minimal stand-ins so this module can be imported (and its data
    # round-tripped) when research_agent is not available.
    @dataclass
    class IdentifiedGap:
        """Fallback gap record with the fields this module relies on."""
        description: str
        suggested_query: str
        identified_at: str
        searched: bool = False
        search_results_count: int = 0

        def to_dict(self) -> Dict[str, Any]:
            """Return the gap's fields as a plain dict."""
            return asdict(self)

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> 'IdentifiedGap':
            """Build a gap from a dict, ignoring keys that are not fields.

            Gap dicts produced by ChapterProject.collect_gaps carry extra
            keys (source_subject, chapter_id, chapter_number); passing them
            straight to the constructor would raise TypeError.
            """
            from dataclasses import fields  # only needed by this fallback

            known = {f.name for f in fields(cls)}
            return cls(**{k: v for k, v in data.items() if k in known})

    # No fallback implementation; code using ResearchSession must handle None.
    ResearchSession = None


# =============================================================================
# CONFIGURATION
# =============================================================================

# Default project directory: a `book_projects` folder alongside the directory
# that contains this module. abspath() guards against a relative __file__,
# for which Path('x.py').parent.parent would collapse to '.' (the cwd).
DEFAULT_PROJECTS_DIR = Path(os.path.abspath(__file__)).parent.parent / 'book_projects'


def generate_project_id(prefix: str = 'BOOK') -> str:
    """Return a new ID shaped like ``<prefix>_<YYYYmmddHHMMSS>_<6 hex chars>``.

    The timestamp is local time at one-second resolution; the random hex
    suffix (from ``secrets``) makes collisions within the same second unlikely.
    """
    stamp = f"{datetime.now():%Y%m%d%H%M%S}"
    suffix = secrets.token_hex(3)
    return f"{prefix}_{stamp}_{suffix}"


# =============================================================================
# SUBJECT RESEARCH
# =============================================================================

@dataclass
class SubjectResearch:
    """Outcome of researching one subject within a chapter.

    ``status`` is one of 'pending', 'in_progress', 'completed' or 'failed';
    ``gaps`` holds the plain gap dicts recorded for this subject.
    """
    subject: str
    query_used: str
    research_session_id: Optional[str] = None

    # Result counts and synthesized text
    chunks_found: int = 0
    documents_found: int = 0
    synthesis: str = ""

    # Gap dicts identified for this subject
    gaps: List[Dict[str, Any]] = field(default_factory=list)

    # Lifecycle timestamps and state
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    status: str = "pending"

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a dict in declaration order (values are not copied)."""
        names = (
            'subject', 'query_used', 'research_session_id',
            'chunks_found', 'documents_found', 'synthesis', 'gaps',
            'started_at', 'completed_at', 'status',
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'SubjectResearch':
        """Build an instance from a dict, using a default for every missing key."""
        defaults = {
            'subject': '',
            'query_used': '',
            'research_session_id': None,
            'chunks_found': 0,
            'documents_found': 0,
            'synthesis': '',
            'gaps': [],
            'started_at': None,
            'completed_at': None,
            'status': 'pending',
        }
        return cls(**{name: data.get(name, fallback) for name, fallback in defaults.items()})


# =============================================================================
# CHAPTER PROJECT
# =============================================================================

@dataclass
class ChapterProject:
    """Research project for a single book chapter.

    Holds one ``SubjectResearch`` entry per subject, the gaps aggregated from
    them, gap-search results, collected sources and the chapter's outputs.
    """
    chapter_id: str
    chapter_number: int
    title: str
    subjects: List[str]

    # Parent reference
    book_project_id: Optional[str] = None

    # Research data (one entry per subject; see __post_init__)
    subject_research: List[SubjectResearch] = field(default_factory=list)

    # Aggregated gaps from all subjects (rebuilt by collect_gaps)
    identified_gaps: List[Dict[str, Any]] = field(default_factory=list)

    # Gap filling results (query -> results)
    gap_search_results: Dict[str, Any] = field(default_factory=dict)

    # Outputs
    research_summary: str = ""
    draft_content: str = ""

    # Sources collected
    library_sources: List[Dict[str, Any]] = field(default_factory=list)
    web_sources: List[Dict[str, Any]] = field(default_factory=list)

    # Status tracking
    status: str = "pending"  # pending, researching, gaps_identified, gap_filling, synthesizing, drafting, complete
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    def __post_init__(self):
        """Create a pending SubjectResearch per subject when none were supplied."""
        if not self.subject_research and self.subjects:
            self.subject_research = [
                SubjectResearch(subject=s, query_used=s)
                for s in self.subjects
            ]

    def update_status(self, new_status: str) -> None:
        """Set the chapter status and refresh ``updated_at``."""
        self.status = new_status
        self.updated_at = datetime.now().isoformat()

    def get_pending_subjects(self) -> List[SubjectResearch]:
        """Return subject entries whose status is still 'pending'."""
        return [sr for sr in self.subject_research if sr.status == 'pending']

    def get_completed_subjects(self) -> List[SubjectResearch]:
        """Return subject entries whose status is 'completed'."""
        return [sr for sr in self.subject_research if sr.status == 'completed']

    def collect_gaps(self) -> List[Dict[str, Any]]:
        """Rebuild ``identified_gaps`` from every subject's gaps and return it.

        Each gap dict is copied and tagged with ``source_subject``,
        ``chapter_id`` and ``chapter_number``. The copies replace the previous
        ``identified_gaps``, so search progress recorded on those entries
        (``searched`` / ``search_results_count``, read by
        ``get_unfilled_gaps``) is carried over for gaps with the same source
        subject, description and suggested query; otherwise every
        re-collection would mark already-searched gaps as unfilled again.
        """
        previously_searched = {
            self._gap_key(g): g
            for g in self.identified_gaps
            if g.get('searched', False)
        }
        all_gaps = []
        for sr in self.subject_research:
            for gap in sr.gaps:
                gap_with_source = {
                    **gap,
                    'source_subject': sr.subject,
                    'chapter_id': self.chapter_id,
                    'chapter_number': self.chapter_number,
                }
                prior = previously_searched.get(self._gap_key(gap_with_source))
                if prior is not None and not gap_with_source.get('searched', False):
                    gap_with_source['searched'] = True
                    if 'search_results_count' in prior:
                        gap_with_source['search_results_count'] = prior['search_results_count']
                all_gaps.append(gap_with_source)
        self.identified_gaps = all_gaps
        return all_gaps

    @staticmethod
    def _gap_key(gap: Dict[str, Any]) -> tuple:
        """Identify a gap across re-collections by subject, description and query."""
        return (
            gap.get('source_subject'),
            gap.get('description'),
            gap.get('suggested_query'),
        )

    def get_unfilled_gaps(self) -> List[Dict[str, Any]]:
        """Return identified gaps not yet marked ``searched``."""
        return [g for g in self.identified_gaps if not g.get('searched', False)]

    @property
    def progress(self) -> float:
        """Fraction (0.0 to 1.0) of subject entries with status 'completed'."""
        if not self.subject_research:
            return 0.0
        completed = sum(1 for sr in self.subject_research if sr.status == 'completed')
        return completed / len(self.subject_research)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the chapter (including subject research) to a JSON-friendly dict."""
        return {
            'chapter_id': self.chapter_id,
            'chapter_number': self.chapter_number,
            'title': self.title,
            'subjects': self.subjects,
            'book_project_id': self.book_project_id,
            'subject_research': [sr.to_dict() for sr in self.subject_research],
            'identified_gaps': self.identified_gaps,
            'gap_search_results': self.gap_search_results,
            'research_summary': self.research_summary,
            'draft_content': self.draft_content,
            'library_sources': self.library_sources,
            'web_sources': self.web_sources,
            'status': self.status,
            'created_at': self.created_at,
            'updated_at': self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ChapterProject':
        """Rebuild a chapter from ``to_dict`` output or a partial dict.

        If ``subject_research`` is present it is loaded as-is. If the key is
        absent (e.g. a dict with only outline fields), the pending entries
        that ``__post_init__`` creates from ``subjects`` are kept instead of
        being replaced by an empty list.
        """
        chapter = cls(
            chapter_id=data.get('chapter_id', ''),
            chapter_number=data.get('chapter_number', 0),
            title=data.get('title', ''),
            subjects=data.get('subjects', []),
            book_project_id=data.get('book_project_id'),
            identified_gaps=data.get('identified_gaps', []),
            gap_search_results=data.get('gap_search_results', {}),
            research_summary=data.get('research_summary', ''),
            draft_content=data.get('draft_content', ''),
            library_sources=data.get('library_sources', []),
            web_sources=data.get('web_sources', []),
            status=data.get('status', 'pending'),
            created_at=data.get('created_at', datetime.now().isoformat()),
            updated_at=data.get('updated_at', datetime.now().isoformat()),
        )
        if 'subject_research' in data:
            chapter.subject_research = [
                SubjectResearch.from_dict(sr)
                for sr in data['subject_research']
            ]
        return chapter


# =============================================================================
# BOOK PROJECT
# =============================================================================

@dataclass
class BookProject:
    """Master project for a book with multiple chapters.

    Work proceeds through five phases (1 Initial Research, 2 Gap Analysis,
    3 Gap Filling, 4 Synthesis, 5 Draft Generation); ``phase_status`` maps
    each phase number to 'pending', 'in_progress' or 'completed'. The project
    is persisted as ``project.json`` inside ``project_dir``.
    """
    project_id: str
    title: str
    author: str = ""
    description: str = ""

    # Chapters
    chapters: List[ChapterProject] = field(default_factory=list)

    # Phase tracking (1-5)
    current_phase: int = 1
    phase_status: Dict[int, str] = field(default_factory=lambda: {
        1: 'pending',  # Initial Research
        2: 'pending',  # Gap Analysis
        3: 'pending',  # Gap Filling
        4: 'pending',  # Synthesis
        5: 'pending',  # Draft Generation
    })

    # Aggregated data (computed from chapters)
    all_gaps: List[Dict[str, Any]] = field(default_factory=list)
    all_library_sources: List[Dict[str, Any]] = field(default_factory=list)
    all_web_sources: List[Dict[str, Any]] = field(default_factory=list)

    # Configuration
    auto_fill_gaps: bool = False
    tavily_credit_budget: int = 50
    max_iterations: int = 5
    use_graphrag: bool = True
    use_rerank: bool = True

    # Timing
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    completed_at: Optional[str] = None

    # Storage path
    project_dir: Optional[str] = None

    def __post_init__(self):
        """Fill in the default project directory and normalize ``phase_status``.

        JSON object keys are always strings, so a project written by ``save``
        and read back by ``load`` arrives with ``phase_status`` keys '1'..'5'.
        Every lookup in this class uses int phase numbers, so digit-string
        keys are converted back to ``int`` here.
        """
        if not self.project_dir:
            self.project_dir = str(DEFAULT_PROJECTS_DIR / self.project_id)
        self.phase_status = {
            (int(k) if isinstance(k, str) and k.isdigit() else k): v
            for k, v in self.phase_status.items()
        }

    @property
    def status(self) -> str:
        """Overall status: 'completed', 'pending' (no phase started) or 'in_progress'."""
        if self.completed_at:
            return 'completed'
        if self.phase_status.get(self.current_phase) == 'in_progress':
            return 'in_progress'
        if all(s == 'pending' for s in self.phase_status.values()):
            return 'pending'
        return 'in_progress'

    def set_phase(self, phase: int, status: str) -> None:
        """Set a phase's status; an 'in_progress' phase also becomes current."""
        self.phase_status[phase] = status
        if status == 'in_progress':
            self.current_phase = phase
        self.updated_at = datetime.now().isoformat()

    def advance_phase(self) -> int:
        """Mark the current phase completed, start the next (up to 5), return it."""
        self.phase_status[self.current_phase] = 'completed'
        if self.current_phase < 5:
            self.current_phase += 1
            self.phase_status[self.current_phase] = 'in_progress'
        self.updated_at = datetime.now().isoformat()
        return self.current_phase

    def get_chapter(self, chapter_number: int) -> Optional[ChapterProject]:
        """Return the first chapter with this number, or None."""
        for chapter in self.chapters:
            if chapter.chapter_number == chapter_number:
                return chapter
        return None

    def get_chapter_by_id(self, chapter_id: str) -> Optional[ChapterProject]:
        """Return the first chapter with this ID, or None."""
        for chapter in self.chapters:
            if chapter.chapter_id == chapter_id:
                return chapter
        return None

    def collect_all_gaps(self) -> List[Dict[str, Any]]:
        """Re-collect each chapter's gaps and store the concatenation in ``all_gaps``."""
        all_gaps = []
        for chapter in self.chapters:
            chapter.collect_gaps()
            all_gaps.extend(chapter.identified_gaps)
        self.all_gaps = all_gaps
        return all_gaps

    def collect_all_sources(self) -> None:
        """Aggregate chapter sources into ``all_library_sources`` / ``all_web_sources``.

        Library sources are deduplicated by ``document_id`` and web sources by
        ``url``; sources lacking that key are skipped. Each aggregated entry is
        a shallow copy of the first occurrence, and its ``chapters_used`` is a
        sorted list of every chapter number citing it (including the first).
        The chapters' own source dicts are left unmodified.
        """
        library_sources: Dict[str, Dict[str, Any]] = {}
        web_sources: Dict[str, Dict[str, Any]] = {}

        for chapter in self.chapters:
            for src in chapter.library_sources:
                self._merge_source(library_sources, src, 'document_id', chapter.chapter_number)
            for src in chapter.web_sources:
                self._merge_source(web_sources, src, 'url', chapter.chapter_number)

        self.all_library_sources = list(library_sources.values())
        self.all_web_sources = list(web_sources.values())

    @staticmethod
    def _merge_source(index: Dict[str, Dict[str, Any]], src: Dict[str, Any],
                      key_field: str, chapter_number: int) -> None:
        """Add ``src`` to ``index`` under ``src[key_field]`` and record the chapter."""
        key = src.get(key_field, '')
        if not key:
            return
        entry = index.get(key)
        if entry is None:
            # Copy so that updating chapters_used never mutates the chapter's dict.
            entry = dict(src)
            index[key] = entry
        chapters_used = set(entry.get('chapters_used', []))
        chapters_used.add(chapter_number)
        entry['chapters_used'] = sorted(chapters_used)

    @property
    def total_subjects(self) -> int:
        """Total number of subjects across all chapters."""
        return sum(len(c.subjects) for c in self.chapters)

    @property
    def completed_subjects(self) -> int:
        """Number of subjects with completed research."""
        return sum(
            len(c.get_completed_subjects())
            for c in self.chapters
        )

    @property
    def progress(self) -> float:
        """Overall project progress (0.0 to 1.0) as a weighted sum over phases."""
        if not self.chapters:
            return 0.0

        # Weight each phase
        phase_weights = {1: 0.3, 2: 0.1, 3: 0.2, 4: 0.2, 5: 0.2}
        total_progress = 0.0

        for phase, weight in phase_weights.items():
            if self.phase_status.get(phase) == 'completed':
                total_progress += weight
            elif self.phase_status.get(phase) == 'in_progress':
                if phase == 1:
                    # Calculate based on subject completion
                    if self.total_subjects > 0:
                        total_progress += weight * (self.completed_subjects / self.total_subjects)
                else:
                    total_progress += weight * 0.5  # Assume halfway

        return total_progress

    def get_status_summary(self) -> Dict[str, Any]:
        """Get status summary for CLI output.

        ``status`` is the same value as the ``status`` property. Gap counts
        reflect ``all_gaps`` as last collected.
        """
        unfilled_gaps = [g for g in self.all_gaps if not g.get('searched', False)]

        return {
            'project_id': self.project_id,
            'title': self.title,
            'status': self.status,
            'current_phase': self.current_phase,
            'phase_status': self.phase_status,
            'progress': round(self.progress, 2),
            'chapters': {
                'total': len(self.chapters),
                'completed': sum(1 for c in self.chapters if c.status == 'complete'),
                'in_progress': sum(1 for c in self.chapters if c.status not in ['pending', 'complete']),
                'pending': sum(1 for c in self.chapters if c.status == 'pending'),
            },
            'subjects': {
                'total': self.total_subjects,
                'completed': self.completed_subjects,
            },
            'gaps': {
                'total': len(self.all_gaps),
                'unfilled': len(unfilled_gaps),
                'filled': len(self.all_gaps) - len(unfilled_gaps),
            },
            'sources': {
                'library': len(self.all_library_sources),
                'web': len(self.all_web_sources),
            },
            'config': {
                'auto_fill_gaps': self.auto_fill_gaps,
                'tavily_budget': self.tavily_credit_budget,
            },
            'updated_at': self.updated_at,
        }

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the project (including chapters) to a JSON-friendly dict."""
        return {
            'project_id': self.project_id,
            'title': self.title,
            'author': self.author,
            'description': self.description,
            'chapters': [c.to_dict() for c in self.chapters],
            'current_phase': self.current_phase,
            'phase_status': self.phase_status,
            'all_gaps': self.all_gaps,
            'all_library_sources': self.all_library_sources,
            'all_web_sources': self.all_web_sources,
            'auto_fill_gaps': self.auto_fill_gaps,
            'tavily_credit_budget': self.tavily_credit_budget,
            'max_iterations': self.max_iterations,
            'use_graphrag': self.use_graphrag,
            'use_rerank': self.use_rerank,
            'created_at': self.created_at,
            'updated_at': self.updated_at,
            'completed_at': self.completed_at,
            'project_dir': self.project_dir,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'BookProject':
        """Rebuild a project from ``to_dict`` output (string phase keys are normalized)."""
        project = cls(
            project_id=data.get('project_id', ''),
            title=data.get('title', ''),
            author=data.get('author', ''),
            description=data.get('description', ''),
            current_phase=data.get('current_phase', 1),
            phase_status=data.get('phase_status', {1: 'pending', 2: 'pending', 3: 'pending', 4: 'pending', 5: 'pending'}),
            all_gaps=data.get('all_gaps', []),
            all_library_sources=data.get('all_library_sources', []),
            all_web_sources=data.get('all_web_sources', []),
            auto_fill_gaps=data.get('auto_fill_gaps', False),
            tavily_credit_budget=data.get('tavily_credit_budget', 50),
            max_iterations=data.get('max_iterations', 5),
            use_graphrag=data.get('use_graphrag', True),
            use_rerank=data.get('use_rerank', True),
            created_at=data.get('created_at', datetime.now().isoformat()),
            updated_at=data.get('updated_at', datetime.now().isoformat()),
            completed_at=data.get('completed_at'),
            project_dir=data.get('project_dir'),
        )
        # Load chapters
        project.chapters = [
            ChapterProject.from_dict(c)
            for c in data.get('chapters', [])
        ]
        return project

    def save(self, path: Optional[Path] = None) -> Path:
        """Write the project as JSON and return the file path.

        Args:
            path: Target location. If omitted, ``<project_dir>/project.json``
                is used. A path that is an existing directory, or that does
                not end in ``.json``, is treated as a directory and
                ``project.json`` is written inside it. Missing parent
                directories are created.

        The JSON is written to a sibling ``.tmp`` file and moved into place
        with ``os.replace``, so an interrupted save leaves the previous file
        intact rather than truncated.
        """
        if path is None:
            path = Path(self.project_dir) / 'project.json'
        else:
            path = Path(path)
            if path.is_dir() or not str(path).endswith('.json'):
                path = path / 'project.json'
        path.parent.mkdir(parents=True, exist_ok=True)

        tmp_path = path.with_name(path.name + '.tmp')
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; re-raise the original error.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

        logger.info(f"Saved book project to {path}")
        return path

    @classmethod
    def load(cls, path: Path) -> 'BookProject':
        """Load a project from a JSON file written by ``save``."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)


# =============================================================================
# WORKFLOW CHECKPOINT
# =============================================================================

@dataclass
class WorkflowCheckpoint:
    """Checkpoint for resuming interrupted workflows.

    Records the phase plus the chapter/subject indices reached, and is
    persisted as ``checkpoint.json`` inside a project directory.
    """
    book_project_id: str
    phase: int
    chapter_index: int
    subject_index: int

    # Timestamps
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    # Phase-specific progress
    phase_progress: Dict[str, Any] = field(default_factory=dict)

    # State hash for validation (see compute_hash)
    state_hash: str = ""

    def update(self, chapter_index: int, subject_index: int) -> None:
        """Move the checkpoint to a new chapter/subject position."""
        self.chapter_index = chapter_index
        self.subject_index = subject_index
        self.updated_at = datetime.now().isoformat()

    def set_phase(self, phase: int) -> None:
        """Switch to ``phase`` and reset the position to chapter 0, subject 0."""
        self.phase = phase
        self.chapter_index = 0
        self.subject_index = 0
        self.updated_at = datetime.now().isoformat()

    def compute_hash(self, project: 'BookProject') -> str:
        """Store and return a short fingerprint of the project's shape.

        The fingerprint covers the project ID, total subject count and chapter
        count only. MD5 is used for change detection, not for security.
        """
        state_str = f"{project.project_id}:{project.total_subjects}:{len(project.chapters)}"
        self.state_hash = hashlib.md5(state_str.encode()).hexdigest()[:12]
        return self.state_hash

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the checkpoint to a JSON-friendly dict."""
        return {
            'book_project_id': self.book_project_id,
            'phase': self.phase,
            'chapter_index': self.chapter_index,
            'subject_index': self.subject_index,
            'created_at': self.created_at,
            'updated_at': self.updated_at,
            'phase_progress': self.phase_progress,
            'state_hash': self.state_hash,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'WorkflowCheckpoint':
        """Rebuild a checkpoint from a dict, using defaults for missing keys."""
        return cls(
            book_project_id=data.get('book_project_id', ''),
            phase=data.get('phase', 1),
            chapter_index=data.get('chapter_index', 0),
            subject_index=data.get('subject_index', 0),
            created_at=data.get('created_at', datetime.now().isoformat()),
            updated_at=data.get('updated_at', datetime.now().isoformat()),
            phase_progress=data.get('phase_progress', {}),
            state_hash=data.get('state_hash', ''),
        )

    def save(self, project_dir: Path) -> Path:
        """Write ``<project_dir>/checkpoint.json`` and return its path.

        ``project_dir`` may be a ``str`` or ``Path`` (``BookProject.project_dir``
        is stored as a ``str``) and is created if missing. The JSON goes to a
        temp file first and is moved into place with ``os.replace`` so an
        interruption mid-write cannot leave a truncated checkpoint.
        """
        directory = Path(project_dir)
        directory.mkdir(parents=True, exist_ok=True)
        path = directory / 'checkpoint.json'
        tmp_path = path.with_name(path.name + '.tmp')
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.to_dict(), f, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; re-raise the original error.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise
        logger.debug(f"Saved checkpoint to {path}")
        return path

    @classmethod
    def load(cls, project_dir: Path) -> Optional['WorkflowCheckpoint']:
        """Load ``<project_dir>/checkpoint.json``, or return None if it doesn't exist.

        ``project_dir`` may be a ``str`` or ``Path``.
        """
        path = Path(project_dir) / 'checkpoint.json'
        if not path.exists():
            return None
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def list_book_projects(projects_dir: Path = None) -> List[Dict[str, Any]]:
    """Summarize every saved book project, most recently updated first.

    Only sub-directories whose names start with ``BOOK_`` and that contain a
    ``project.json`` are considered; a project that fails to load is logged
    as a warning and skipped.
    """
    root = DEFAULT_PROJECTS_DIR if projects_dir is None else projects_dir
    if not root.exists():
        return []

    summaries = []
    for entry in root.iterdir():
        if not (entry.is_dir() and entry.name.startswith('BOOK_')):
            continue
        project_file = entry / 'project.json'
        if not project_file.exists():
            continue
        try:
            project = BookProject.load(project_file)
            summaries.append({
                'project_id': project.project_id,
                'title': project.title,
                'chapters': len(project.chapters),
                'phase': project.current_phase,
                'progress': round(project.progress, 2),
                'updated_at': project.updated_at,
            })
        except Exception as e:
            logger.warning(f"Could not load project {entry}: {e}")

    # ISO-8601 timestamps sort chronologically as strings
    summaries.sort(key=lambda p: p['updated_at'], reverse=True)
    return summaries


def load_book_project(project_id: str, projects_dir: Path = None) -> Optional[BookProject]:
    """Load the project stored at ``<projects_dir>/<project_id>/project.json``.

    ``projects_dir`` defaults to ``DEFAULT_PROJECTS_DIR``. Returns None, after
    logging an error, when that file does not exist.
    """
    base = projects_dir if projects_dir is not None else DEFAULT_PROJECTS_DIR
    project_file = base / project_id / 'project.json'
    if project_file.exists():
        return BookProject.load(project_file)
    logger.error(f"Project not found: {project_id}")
    return None


def delete_book_project(project_id: str, projects_dir: Path = None) -> bool:
    """Delete a book project's directory and everything inside it.

    Args:
        project_id: Name of the project directory under ``projects_dir``.
        projects_dir: Root holding project directories; defaults to
            ``DEFAULT_PROJECTS_DIR``.

    Returns:
        True if the directory was removed. False if ``project_id`` is not a
        single plain path component or no such project directory exists.
    """
    import shutil

    if projects_dir is None:
        projects_dir = DEFAULT_PROJECTS_DIR

    # Accept only a single directory name: '' or '.' would resolve to
    # projects_dir itself (wiping every project), and '..', 'a/b' or an
    # absolute path would let rmtree reach outside the projects directory.
    if project_id in ('', '.', '..') or Path(project_id).name != project_id:
        return False

    project_path = Path(projects_dir) / project_id
    if not project_path.is_dir():
        return False

    shutil.rmtree(project_path)
    logger.info(f"Deleted project: {project_id}")
    return True


# =============================================================================
# EXPORTS
# =============================================================================

# Public API, grouped by kind; star-imports are unaffected by ordering.
__all__ = [
    # Data models
    'BookProject',
    'ChapterProject',
    'IdentifiedGap',
    'SubjectResearch',
    'WorkflowCheckpoint',
    # Project helpers
    'delete_book_project',
    'generate_project_id',
    'list_book_projects',
    'load_book_project',
    # Configuration
    'DEFAULT_PROJECTS_DIR',
]
