#!/usr/bin/env python3
"""
Pipeline Versioning System

Explicit versioning for pipeline artifacts to prevent "silent drift" where
search behavior changes and nobody can explain why.

Versioned Artifacts:
- Chunk Sets: (doc_id, chunk_strategy, chunk_params) → version_id
- Embedding Sets: (chunk_set_id, embedding_model, dim, provider) → version_id
- Graph Extraction Runs: (extractor, entity_types, params) → version_id

Usage:
    from pipeline_versioning import VersionRegistry, ChunkSetVersion

    registry = VersionRegistry()

    # Create versioned chunk set
    chunk_version = registry.create_chunk_set_version(
        doc_id="DOC_001",
        strategy="semantic",
        params={"threshold": 0.7}
    )

    # Create versioned embedding set
    embed_version = registry.create_embedding_set_version(
        chunk_set_version_id=chunk_version.version_id,
        model="text-embedding-3-small",
        dimension=1536,
        provider="openai"
    )

    # Query versions
    versions = registry.get_chunk_versions_for_document("DOC_001")
"""

import hashlib
import json
import logging
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
import sqlite3

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Current pipeline version. It is the default ``pipeline_version`` of every
# version record defined below and is embedded in the chunk-set and
# graph-extraction version IDs, so changing it yields new IDs for those.
PIPELINE_VERSION = "3.3.0"


# =============================================================================
# VERSION DATA STRUCTURES
# =============================================================================

@dataclass
class ChunkSetVersion:
    """
    Describes one versioned chunking of a single document.

    The identity of a chunk set is derived from its document, the chunking
    strategy, the chunking parameters and the pipeline version, so several
    chunkings of the same document can be recorded side by side and an older
    one is not lost when the document is re-chunked.
    """
    doc_id: str
    strategy: str  # e.g. fixed, semantic, hierarchical
    params: Dict[str, Any]  # strategy-specific chunking parameters
    pipeline_version: str = PIPELINE_VERSION
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    chunk_count: int = 0
    is_active: bool = True  # whether this version is the one currently in use
    notes: str = ""

    @property
    def params_hash(self) -> str:
        """First 8 hex digits of the MD5 of the key-sorted JSON form of ``params``."""
        # sort_keys makes the hash independent of dict insertion order.
        canonical = json.dumps(self.params, sort_keys=True).encode()
        digest = hashlib.md5(canonical).hexdigest()
        return digest[:8]

    @property
    def version_id(self) -> str:
        """Identifier built from doc id, strategy, params hash and pipeline version."""
        return f"{self.doc_id}_{self.strategy}_{self.params_hash}_{self.pipeline_version}"

    def to_dict(self) -> Dict:
        """Return all fields as a dict, plus the derived ``version_id`` and ``params_hash``."""
        payload = asdict(self)
        payload.update(version_id=self.version_id, params_hash=self.params_hash)
        return payload


@dataclass
class EmbeddingSetVersion:
    """
    Describes one versioned set of embeddings computed for a chunk set.

    The identity of an embedding set is derived from the chunk set it was
    computed from, the embedding model, the vector dimension and the
    provider, so embeddings from several models can be recorded side by side
    and an older set is not lost when chunks are re-embedded.
    """
    chunk_set_version_id: str
    model: str  # e.g. text-embedding-3-small, nomic-embed-text
    dimension: int
    provider: str  # e.g. openai, local, huggingface
    pipeline_version: str = PIPELINE_VERSION
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    embedding_count: int = 0
    is_active: bool = True
    notes: str = ""

    @property
    def version_id(self) -> str:
        """Identifier built from chunk set version id, model, dimension and provider."""
        return f"{self.chunk_set_version_id}_{self.model}_{self.dimension}_{self.provider}"

    def to_dict(self) -> Dict:
        """Return all fields as a dict, plus the derived ``version_id``."""
        return {**asdict(self), 'version_id': self.version_id}


@dataclass
class GraphExtractionVersion:
    """
    Describes one versioned knowledge-graph extraction run.

    The version ID is derived from the extractor, the set of entity types,
    whether relations were extracted, and the pipeline version. The other
    fields (relation backend, confidence threshold, counters, notes) are
    recorded with the run but do not take part in the ID.
    """
    extractor: str  # e.g. gliner, openai, hybrid, llama
    entity_types: List[str]
    extract_relations: bool = False
    relation_backend: Optional[str] = None  # used in hybrid mode
    confidence_threshold: float = 0.0
    pipeline_version: str = PIPELINE_VERSION
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    entities_extracted: int = 0
    relations_extracted: int = 0
    is_active: bool = True
    notes: str = ""

    @property
    def types_hash(self) -> str:
        """First 8 hex digits of the MD5 of the sorted, underscore-joined entity types."""
        # Sorting makes the hash independent of the order the types were given in.
        joined_types = "_".join(sorted(self.entity_types))
        digest = hashlib.md5(joined_types.encode()).hexdigest()
        return digest[:8]

    @property
    def version_id(self) -> str:
        """Identifier built from extractor, types hash, relation flag and pipeline version."""
        if self.extract_relations:
            rel_flag = "rel"
        else:
            rel_flag = "norel"
        return f"{self.extractor}_{self.types_hash}_{rel_flag}_{self.pipeline_version}"

    def to_dict(self) -> Dict:
        """Return all fields as a dict, plus the derived ``version_id`` and ``types_hash``."""
        payload = asdict(self)
        payload.update(version_id=self.version_id, types_hash=self.types_hash)
        return payload


# =============================================================================
# VERSION REGISTRY
# =============================================================================

class VersionRegistry:
    """
    Registry for tracking pipeline artifact versions.

    Provides version management for chunks, embeddings, and graph extractions,
    enabling safe migrations and rollbacks.
    """

    def __init__(self, db_path: Optional[Path] = None):
        """
        Initialize version registry.

        Args:
            db_path: Path to SQLite database for version storage.
                    If None, uses in-memory storage.
        """
        self.db_path = db_path
        self._init_storage()

    def _init_storage(self):
        """Initialize storage backend."""
        if self.db_path:
            self.conn = sqlite3.connect(str(self.db_path))
        else:
            self.conn = sqlite3.connect(":memory:")

        self.conn.row_factory = sqlite3.Row
        self._create_tables()

    def _create_tables(self):
        """Create version tables."""
        cursor = self.conn.cursor()

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS chunk_set_versions (
                version_id TEXT PRIMARY KEY,
                doc_id TEXT NOT NULL,
                strategy TEXT NOT NULL,
                params TEXT NOT NULL,
                params_hash TEXT NOT NULL,
                pipeline_version TEXT NOT NULL,
                created_at TEXT NOT NULL,
                chunk_count INTEGER DEFAULT 0,
                is_active INTEGER DEFAULT 1,
                notes TEXT
            )
        """)

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embedding_set_versions (
                version_id TEXT PRIMARY KEY,
                chunk_set_version_id TEXT NOT NULL,
                model TEXT NOT NULL,
                dimension INTEGER NOT NULL,
                provider TEXT NOT NULL,
                pipeline_version TEXT NOT NULL,
                created_at TEXT NOT NULL,
                embedding_count INTEGER DEFAULT 0,
                is_active INTEGER DEFAULT 1,
                notes TEXT,
                FOREIGN KEY (chunk_set_version_id) REFERENCES chunk_set_versions(version_id)
            )
        """)

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS graph_extraction_versions (
                version_id TEXT PRIMARY KEY,
                extractor TEXT NOT NULL,
                entity_types TEXT NOT NULL,
                types_hash TEXT NOT NULL,
                extract_relations INTEGER DEFAULT 0,
                relation_backend TEXT,
                confidence_threshold REAL DEFAULT 0.0,
                pipeline_version TEXT NOT NULL,
                created_at TEXT NOT NULL,
                entities_extracted INTEGER DEFAULT 0,
                relations_extracted INTEGER DEFAULT 0,
                is_active INTEGER DEFAULT 1,
                notes TEXT
            )
        """)

        # Indexes
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunk_doc ON chunk_set_versions(doc_id)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunk_active ON chunk_set_versions(is_active)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_embed_chunk ON embedding_set_versions(chunk_set_version_id)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_embed_active ON embedding_set_versions(is_active)")

        self.conn.commit()

    # -------------------------------------------------------------------------
    # Chunk Set Versions
    # -------------------------------------------------------------------------

    def create_chunk_set_version(
        self,
        doc_id: str,
        strategy: str,
        params: Dict[str, Any],
        chunk_count: int = 0,
        notes: str = ""
    ) -> ChunkSetVersion:
        """
        Record a chunk set version and return it.

        The row is written with ``INSERT OR REPLACE``, so an existing row with
        the same version ID is overwritten by the new record.

        Args:
            doc_id: Identifier of the chunked document.
            strategy: Name of the chunking strategy.
            params: Chunking parameters; stored as JSON.
            chunk_count: Number of chunks in the set.
            notes: Free-form notes stored with the version.

        Returns:
            The newly built ChunkSetVersion.
        """
        version = ChunkSetVersion(
            doc_id=doc_id,
            strategy=strategy,
            params=params,
            chunk_count=chunk_count,
            notes=notes,
        )

        insert_sql = """
            INSERT OR REPLACE INTO chunk_set_versions
            (version_id, doc_id, strategy, params, params_hash, pipeline_version,
             created_at, chunk_count, is_active, notes)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        # Values in the same order as the column list above.
        row_values = (
            version.version_id,
            version.doc_id,
            version.strategy,
            json.dumps(version.params),
            version.params_hash,
            version.pipeline_version,
            version.created_at,
            version.chunk_count,
            int(version.is_active),  # stored as 0/1
            version.notes,
        )

        self.conn.execute(insert_sql, row_values)
        self.conn.commit()

        logger.info(f"Created chunk set version: {version.version_id}")
        return version

    def get_chunk_set_version(self, version_id: str) -> Optional[ChunkSetVersion]:
        """Get chunk set version by ID."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM chunk_set_versions WHERE version_id = ?", (version_id,))
        row = cursor.fetchone()

        if row:
            return ChunkSetVersion(
                doc_id=row['doc_id'],
                strategy=row['strategy'],
                params=json.loads(row['params']),
                pipeline_version=row['pipeline_version'],
                created_at=row['created_at'],
                chunk_count=row['chunk_count'],
                is_active=bool(row['is_active']),
                notes=row['notes'] or ""
            )
        return None

    def get_active_chunk_version(self, doc_id: str) -> Optional[ChunkSetVersion]:
        """Get active chunk set version for a document."""
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT * FROM chunk_set_versions
            WHERE doc_id = ? AND is_active = 1
            ORDER BY created_at DESC LIMIT 1
        """, (doc_id,))
        row = cursor.fetchone()

        if row:
            return ChunkSetVersion(
                doc_id=row['doc_id'],
                strategy=row['strategy'],
                params=json.loads(row['params']),
                pipeline_version=row['pipeline_version'],
                created_at=row['created_at'],
                chunk_count=row['chunk_count'],
                is_active=True,
                notes=row['notes'] or ""
            )
        return None

    def get_chunk_versions_for_document(self, doc_id: str) -> List[ChunkSetVersion]:
        """Get all chunk set versions for a document."""
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT * FROM chunk_set_versions
            WHERE doc_id = ?
            ORDER BY created_at DESC
        """, (doc_id,))

        versions = []
        for row in cursor.fetchall():
            versions.append(ChunkSetVersion(
                doc_id=row['doc_id'],
                strategy=row['strategy'],
                params=json.loads(row['params']),
                pipeline_version=row['pipeline_version'],
                created_at=row['created_at'],
                chunk_count=row['chunk_count'],
                is_active=bool(row['is_active']),
                notes=row['notes'] or ""
            ))
        return versions

    def deactivate_chunk_version(self, version_id: str):
        """Deactivate a chunk set version."""
        cursor = self.conn.cursor()
        cursor.execute(
            "UPDATE chunk_set_versions SET is_active = 0 WHERE version_id = ?",
            (version_id,)
        )
        self.conn.commit()

    # -------------------------------------------------------------------------
    # Embedding Set Versions
    # -------------------------------------------------------------------------

    def create_embedding_set_version(
        self,
        chunk_set_version_id: str,
        model: str,
        dimension: int,
        provider: str,
        embedding_count: int = 0,
        notes: str = ""
    ) -> EmbeddingSetVersion:
        """
        Record an embedding set version and return it.

        The row is written with ``INSERT OR REPLACE``, so an existing row with
        the same version ID is overwritten by the new record.

        Args:
            chunk_set_version_id: ID of the chunk set the embeddings belong to.
            model: Name of the embedding model.
            dimension: Dimension of the embedding vectors.
            provider: Name of the embedding provider.
            embedding_count: Number of embeddings in the set.
            notes: Free-form notes stored with the version.

        Returns:
            The newly built EmbeddingSetVersion.
        """
        version = EmbeddingSetVersion(
            chunk_set_version_id=chunk_set_version_id,
            model=model,
            dimension=dimension,
            provider=provider,
            embedding_count=embedding_count,
            notes=notes,
        )

        insert_sql = """
            INSERT OR REPLACE INTO embedding_set_versions
            (version_id, chunk_set_version_id, model, dimension, provider,
             pipeline_version, created_at, embedding_count, is_active, notes)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        # Values in the same order as the column list above.
        row_values = (
            version.version_id,
            version.chunk_set_version_id,
            version.model,
            version.dimension,
            version.provider,
            version.pipeline_version,
            version.created_at,
            version.embedding_count,
            int(version.is_active),  # stored as 0/1
            version.notes,
        )

        self.conn.execute(insert_sql, row_values)
        self.conn.commit()

        logger.info(f"Created embedding set version: {version.version_id}")
        return version

    def get_embedding_versions_for_chunks(self, chunk_set_version_id: str) -> List[EmbeddingSetVersion]:
        """Get all embedding versions for a chunk set."""
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT * FROM embedding_set_versions
            WHERE chunk_set_version_id = ?
            ORDER BY created_at DESC
        """, (chunk_set_version_id,))

        versions = []
        for row in cursor.fetchall():
            versions.append(EmbeddingSetVersion(
                chunk_set_version_id=row['chunk_set_version_id'],
                model=row['model'],
                dimension=row['dimension'],
                provider=row['provider'],
                pipeline_version=row['pipeline_version'],
                created_at=row['created_at'],
                embedding_count=row['embedding_count'],
                is_active=bool(row['is_active']),
                notes=row['notes'] or ""
            ))
        return versions

    # -------------------------------------------------------------------------
    # Graph Extraction Versions
    # -------------------------------------------------------------------------

    def create_graph_extraction_version(
        self,
        extractor: str,
        entity_types: List[str],
        extract_relations: bool = False,
        relation_backend: Optional[str] = None,
        confidence_threshold: float = 0.0,
        entities_extracted: int = 0,
        relations_extracted: int = 0,
        notes: str = ""
    ) -> GraphExtractionVersion:
        """
        Record a graph extraction run and return its version.

        The row is written with ``INSERT OR REPLACE``, so an existing row with
        the same version ID is overwritten by the new record.

        Args:
            extractor: Name of the extractor that was used.
            entity_types: Entity types that were extracted; stored as JSON.
            extract_relations: Whether relations were extracted as well.
            relation_backend: Backend used for relation extraction, if any.
            confidence_threshold: Confidence threshold applied to the run.
            entities_extracted: Number of entities extracted.
            relations_extracted: Number of relations extracted.
            notes: Free-form notes stored with the version.

        Returns:
            The newly built GraphExtractionVersion.
        """
        version = GraphExtractionVersion(
            extractor=extractor,
            entity_types=entity_types,
            extract_relations=extract_relations,
            relation_backend=relation_backend,
            confidence_threshold=confidence_threshold,
            entities_extracted=entities_extracted,
            relations_extracted=relations_extracted,
            notes=notes,
        )

        insert_sql = """
            INSERT OR REPLACE INTO graph_extraction_versions
            (version_id, extractor, entity_types, types_hash, extract_relations,
             relation_backend, confidence_threshold, pipeline_version, created_at,
             entities_extracted, relations_extracted, is_active, notes)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        # Values in the same order as the column list above.
        row_values = (
            version.version_id,
            version.extractor,
            json.dumps(version.entity_types),
            version.types_hash,
            1 if version.extract_relations else 0,  # stored as 0/1
            version.relation_backend,
            version.confidence_threshold,
            version.pipeline_version,
            version.created_at,
            version.entities_extracted,
            version.relations_extracted,
            int(version.is_active),  # stored as 0/1
            version.notes,
        )

        self.conn.execute(insert_sql, row_values)
        self.conn.commit()

        logger.info(f"Created graph extraction version: {version.version_id}")
        return version

    # -------------------------------------------------------------------------
    # Version Comparison
    # -------------------------------------------------------------------------

    def compare_versions(
        self,
        version_id_1: str,
        version_id_2: str
    ) -> Dict[str, Any]:
        """Compare two versions and show differences."""
        # Determine version type and fetch both
        v1 = self.get_chunk_set_version(version_id_1)
        v2 = self.get_chunk_set_version(version_id_2)

        if not v1 or not v2:
            return {'error': 'One or both versions not found'}

        comparison = {
            'version_1': version_id_1,
            'version_2': version_id_2,
            'differences': {},
        }

        # Compare fields
        for field_name in ['strategy', 'params', 'pipeline_version', 'chunk_count']:
            val1 = getattr(v1, field_name)
            val2 = getattr(v2, field_name)
            if val1 != val2:
                comparison['differences'][field_name] = {
                    'v1': val1,
                    'v2': val2
                }

        return comparison

    # -------------------------------------------------------------------------
    # Migration Support
    # -------------------------------------------------------------------------

    def plan_migration(
        self,
        from_version_id: str,
        to_strategy: str,
        to_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Plan a migration from one chunking strategy to another.

        Returns a plan that can be reviewed before execution.
        """
        current = self.get_chunk_set_version(from_version_id)
        if not current:
            return {'error': f'Version not found: {from_version_id}'}

        plan = {
            'from_version': from_version_id,
            'from_strategy': current.strategy,
            'from_params': current.params,
            'to_strategy': to_strategy,
            'to_params': to_params,
            'affected_chunks': current.chunk_count,
            'actions': [
                f"Create new chunk set with strategy '{to_strategy}'",
                "Re-chunk document with new parameters",
                f"Deactivate old version ({from_version_id})",
                "Update embeddings for new chunks",
            ],
            'reversible': True,
            'notes': "Old chunks will be preserved but marked inactive"
        }

        return plan


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def get_default_registry() -> VersionRegistry:
    """
    Open the registry stored at ``BASE_DIR/data/pipeline_versions.db``.

    ``BASE_DIR`` is imported from the project's ``config`` module at call
    time (presumably a ``pathlib.Path``; confirm against ``config``). The
    ``data`` directory is created if it does not exist. Every call builds a
    new VersionRegistry with its own connection.
    """
    from config import BASE_DIR

    registry_file = BASE_DIR / 'data' / 'pipeline_versions.db'
    registry_file.parent.mkdir(parents=True, exist_ok=True)
    return VersionRegistry(registry_file)


def record_chunking_version(
    doc_id: str,
    strategy: str,
    params: Dict[str, Any],
    chunk_count: int
) -> ChunkSetVersion:
    """
    Record a chunking operation in the default registry.

    Opens the default registry, stores one chunk set version and closes the
    registry's SQLite connection again before returning.

    Args:
        doc_id: Identifier of the chunked document.
        strategy: Name of the chunking strategy that was used.
        params: Chunking parameters that were used.
        chunk_count: Number of chunks that were produced.

    Returns:
        The ChunkSetVersion that was stored.
    """
    registry = get_default_registry()
    try:
        return registry.create_chunk_set_version(
            doc_id=doc_id,
            strategy=strategy,
            params=params,
            chunk_count=chunk_count
        )
    finally:
        # get_default_registry() opens a fresh connection on every call; close
        # it here so repeated calls do not accumulate open database handles.
        registry.conn.close()


def record_embedding_version(
    chunk_set_version_id: str,
    model: str,
    dimension: int,
    provider: str,
    embedding_count: int
) -> EmbeddingSetVersion:
    """
    Record an embedding operation in the default registry.

    Opens the default registry, stores one embedding set version and closes
    the registry's SQLite connection again before returning.

    Args:
        chunk_set_version_id: ID of the chunk set that was embedded.
        model: Name of the embedding model that was used.
        dimension: Dimension of the embedding vectors.
        provider: Name of the embedding provider.
        embedding_count: Number of embeddings that were produced.

    Returns:
        The EmbeddingSetVersion that was stored.
    """
    registry = get_default_registry()
    try:
        return registry.create_embedding_set_version(
            chunk_set_version_id=chunk_set_version_id,
            model=model,
            dimension=dimension,
            provider=provider,
            embedding_count=embedding_count
        )
    finally:
        # get_default_registry() opens a fresh connection on every call; close
        # it here so repeated calls do not accumulate open database handles.
        registry.conn.close()


# =============================================================================
# CLI
# =============================================================================

if __name__ == '__main__':
    # Command-line entry point: list a document's chunk versions and/or
    # compare two chunk set versions using the default registry.
    import argparse

    arg_parser = argparse.ArgumentParser(description='Pipeline versioning utilities')
    arg_parser.add_argument('--list-chunk-versions', type=str, metavar='DOC_ID',
                            help='List chunk versions for a document')
    arg_parser.add_argument('--compare', nargs=2, metavar=('V1', 'V2'),
                            help='Compare two versions')
    arg_parser.add_argument('--format', choices=['text', 'json'], default='text')

    cli_args = arg_parser.parse_args()

    default_registry = get_default_registry()

    doc_id = cli_args.list_chunk_versions
    if doc_id:
        found = default_registry.get_chunk_versions_for_document(doc_id)
        if cli_args.format == 'json':
            print(json.dumps([item.to_dict() for item in found], indent=2))
        else:
            print(f"\nChunk versions for {doc_id}:")
            for item in found:
                status = "ACTIVE" if item.is_active else "inactive"
                print(f"  {item.version_id} [{status}]")
                print(f"    Strategy: {item.strategy}, Chunks: {item.chunk_count}")
                print(f"    Created: {item.created_at}")

    # --compare always prints JSON, regardless of --format.
    if cli_args.compare:
        first_id, second_id = cli_args.compare
        report = default_registry.compare_versions(first_id, second_id)
        print(json.dumps(report, indent=2))
