#!/usr/bin/env python3
"""
Test script for the domain analyzer.

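Verifies that:
- "Write a short story" is correctly identified as CREATIVE_WRITING
- Technical tasks are correctly identified as TECHNICAL_CODING
- Creative tasks that mention "test" are not misclassified
- The agent template registry selects appropriate agents
- Team composition picks story-writer as the lead for a story task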
"""

import sys
sys.path.insert(0, 'src')

from agent_orchestrator.intelligence.domain_analyzer import (
    analyze_task,
    TaskDomain,
    is_creative_task,
    is_technical_task,
)
from agent_orchestrator.agents.templates import (
    select_agents_for_task,
    compose_team_for_task,
    list_templates,
)


def test_creative_writing():
    """Check that fiction/poetry prompts are classified as CREATIVE_WRITING.

    Prints a PASS/FAIL line per prompt with the detected domain and its
    confidence, plus the expected domain for any prompt that was missed.

    Returns:
        True when every prompt lands in CREATIVE_WRITING, else False.
    """
    banner = "=" * 60
    print(banner)
    print("TESTING CREATIVE WRITING DETECTION")
    print(banner)

    expected = TaskDomain.CREATIVE_WRITING
    prompts = (
        "Write a short story about a robot learning to love",
        "Create a poem about the ocean",
        "Write a novel chapter about a detective",
        "Compose a short fiction piece about time travel",
        "Help me write a creative story",
    )

    failures = 0
    for prompt in prompts:
        result = analyze_task(prompt)
        ok = result.primary_domain == expected
        print(f"\n[{'PASS' if ok else 'FAIL'}] Task: {prompt}")
        print(f"  Domain: {result.primary_domain.value}")
        print(f"  Confidence: {result.primary_confidence:.0%}")
        if ok:
            continue
        failures += 1
        print(f"  Expected: {expected.value}")

    return failures == 0


def test_technical_coding():
    """Check that engineering prompts are classified as TECHNICAL_CODING.

    Prints a PASS/FAIL line per prompt with the detected domain and its
    confidence, plus the expected domain for any prompt that was missed.

    Returns:
        True when every prompt lands in TECHNICAL_CODING, else False.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TESTING TECHNICAL CODING DETECTION")
    print(rule)

    target = TaskDomain.TECHNICAL_CODING
    prompts = (
        "Implement a REST API endpoint for user authentication",
        "Fix the bug in the login function",
        "Write unit tests for the payment module",
        "Refactor the database queries for better performance",
        "Add a React component for the dashboard",
    )

    misses = []
    for prompt in prompts:
        outcome = analyze_task(prompt)
        domain = outcome.primary_domain
        hit = domain == target
        print(f"\n[{'PASS' if hit else 'FAIL'}] Task: {prompt}")
        print(f"  Domain: {domain.value}")
        print(f"  Confidence: {outcome.primary_confidence:.0%}")
        if not hit:
            misses.append(prompt)
            print(f"  Expected: {target.value}")

    return not misses


def test_negative_cases():
    """Test that creative tasks mentioning "test" are not misclassified.

    Each case pairs a prompt with the domain ``analyze_task`` is expected to
    report as primary. The prompts deliberately contain the word "test" so a
    naive keyword match could steer them away from creative writing.

    Returns:
        True if every prompt was classified into its expected domain,
        False otherwise.
    """
    print("\n" + "=" * 60)
    print("TESTING NEGATIVE CASES (no misclassification)")
    print("=" * 60)

    # Creative prompts that contain the word "test". They must still be
    # classified as CREATIVE_WRITING; the failure message below notes that
    # such prompts previously routed to test-engineer. This test only checks
    # the domain, not which agent is ultimately selected.
    test_cases = [
        ("Write a short story about a test pilot", TaskDomain.CREATIVE_WRITING),
        ("Create a narrative about testing limits", TaskDomain.CREATIVE_WRITING),
        ("Story about a character taking a test", TaskDomain.CREATIVE_WRITING),
    ]

    all_passed = True
    for task, expected_domain in test_cases:
        analysis = analyze_task(task)
        passed = analysis.primary_domain == expected_domain

        status = "PASS" if passed else "FAIL"
        print(f"\n[{status}] Task: {task}")
        print(f"  Domain: {analysis.primary_domain.value}")
        print(f"  Confidence: {analysis.primary_confidence:.0%}")

        if not passed:
            all_passed = False
            print(f"  Expected: {expected_domain.value}")
            # Plain string: there is nothing to interpolate here.
            print("  This would have previously triggered test-engineer incorrectly!")

    return all_passed


def test_agent_selection():
    """Test that the right agent template is ranked first for each task.

    For each task, ``select_agents_for_task`` is called on the analysis and
    the first returned match is treated as the selected agent. A case passes
    when the expected template id occurs as a substring of the selected
    template's id. An exact match therefore passes, but so does a longer id
    that contains the expected one. A task with no matches fails.

    Returns:
        True if every case passed, False otherwise.
    """
    print("\n" + "=" * 60)
    print("TESTING AGENT SELECTION")
    print("=" * 60)

    test_cases = [
        ("Write a short story about love", "story-writer"),
        ("Implement a REST API", "backend-specialist"),
        ("Write React components", "frontend-specialist"),
        ("Research market trends", "researcher"),
    ]

    all_passed = True
    for task, expected_template in test_cases:
        analysis = analyze_task(task)
        matches = select_agents_for_task(analysis)

        if matches:
            best_match = matches[0].template.id
            # Substring containment already covers exact equality, so no
            # separate `==` comparison is needed.
            passed = expected_template in best_match
        else:
            passed = False
            best_match = "none"

        status = "PASS" if passed else "FAIL"
        print(f"\n[{status}] Task: {task}")
        print(f"  Domain: {analysis.primary_domain.value}")
        print(f"  Selected Agent: {best_match}")

        if not passed:
            all_passed = False
            print(f"  Expected: {expected_template}")

    return all_passed


def test_team_composition():
    """Test team composition for a story task that also asks for editing.

    Composes a team for the task and prints its lead, specialists and
    reviewers. The test passes when the lead's template id is
    ``"story-writer"``.

    Returns:
        True if the team lead's id is ``"story-writer"``, False otherwise.
    """
    print("\n" + "=" * 60)
    print("TESTING TEAM COMPOSITION")
    print("=" * 60)

    expected_lead = "story-writer"
    task = "Write a comprehensive short story with proper editing"
    analysis = analyze_task(task)
    team = compose_team_for_task(analysis)

    print(f"\nTask: {task}")
    print(f"Domain: {analysis.primary_domain.value}")
    print(f"Complexity: {analysis.complexity}")
    print("\nTeam Composition:")
    print(f"  Lead: {team.lead.name} ({team.lead.id})")

    for specialist in team.specialists:
        print(f"  Specialist: {specialist.name} ({specialist.id})")

    for reviewer in team.reviewers:
        print(f"  Reviewer: {reviewer.name} ({reviewer.id})")

    passed = team.lead.id == expected_lead

    # Report the verdict explicitly, like the other tests, so a failure is
    # visible in the detailed output and not only in the final summary.
    status = "PASS" if passed else "FAIL"
    print(f"\n[{status}] Team lead check")
    if not passed:
        print(f"  Expected Lead: {expected_lead}")

    return passed


def list_all_templates():
    """Print every available agent template, grouped by domain.

    Domains are printed in the order they are first seen in
    ``list_templates()``. Their names are title-cased, with underscores
    replaced by spaces. Each description is cut to 50 characters, and "..."
    is appended only when something was actually cut off.
    """
    print("\n" + "=" * 60)
    print("AVAILABLE AGENT TEMPLATES")
    print("=" * 60)

    max_description_length = 50

    by_domain = {}
    for template in list_templates():
        # Plain dicts keep insertion order, so domains print in first-seen order.
        by_domain.setdefault(template.domain.value, []).append(template)

    for domain, templates_list in by_domain.items():
        print(f"\n{domain.replace('_', ' ').title()}:")
        for t in templates_list:
            description = t.description
            # Only add an ellipsis when text was actually truncated.
            if len(description) > max_description_length:
                description = description[:max_description_length] + "..."
            print(f"  - {t.id}: {description}")


def main():
    """Run all tests, print a summary, and return a process exit code.

    Each test function prints its own detailed output and returns a truthy
    value on success. If a test raises, the traceback is printed and the test
    is counted as failed. The remaining tests and the summary still run.

    Returns:
        0 if every test passed, 1 otherwise.
    """
    import traceback

    print("\n" + "=" * 60)
    print("DOMAIN ANALYZER AND AGENT SELECTION TESTS")
    print("=" * 60)

    tests = {
        "Creative Writing Detection": test_creative_writing,
        "Technical Coding Detection": test_technical_coding,
        "Negative Cases": test_negative_cases,
        "Agent Selection": test_agent_selection,
        "Team Composition": test_team_composition,
    }

    results = {}
    for test_name, test_fn in tests.items():
        try:
            results[test_name] = bool(test_fn())
        except Exception:
            # Top-level harness boundary: report the crash and keep going so
            # one broken test does not hide the results of the others.
            print(f"\n[ERROR] {test_name} raised an exception:")
            traceback.print_exc()
            results[test_name] = False

    list_all_templates()

    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)

    for test_name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f"  [{status}] {test_name}")

    all_passed = all(results.values())

    print("\n" + "=" * 60)
    if all_passed:
        print("ALL TESTS PASSED!")
    else:
        print("SOME TESTS FAILED!")
    print("=" * 60)

    return 0 if all_passed else 1


if __name__ == "__main__":
    # Propagate main()'s return value (0 = all passed, 1 = failures) as the
    # process exit status; equivalent to sys.exit(main()).
    raise SystemExit(main())
