"""
VotingCoordinator - Multi-agent voting and consensus mechanism.

Provides:
- Voting session management
- Multiple quorum types (simple majority, supermajority, unanimous)
- Tie-breaking strategies
- Confidence-weighted voting
- Audit trail of decisions
"""

import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Optional

from ..persistence.database import OrchestratorDB
from ..persistence.models import Vote, VotingSession, QuorumType

logger = logging.getLogger(__name__)


class TieBreaker(Enum):
    """How to choose among options that share the highest vote count.

    Members:
        FIRST_VOTE: the tied option that appears first in the vote list wins.
        HIGHEST_CONFIDENCE: the tied option with the largest
            confidence-weighted total wins.
        RANDOM: one of the tied options is picked at random
            (not recommended).
        NO_DECISION: the tie is left unresolved and no winner is declared.
    """

    FIRST_VOTE = "first_vote"
    HIGHEST_CONFIDENCE = "highest_confidence"
    RANDOM = "random"
    NO_DECISION = "no_decision"


@dataclass
class VotingResult:
    """Outcome of tallying one voting session, with its full vote breakdown."""

    session_id: str
    topic: str
    # Winning option, or None when no option prevailed.
    winner: Optional[str]
    # Option -> number of votes received.
    vote_counts: dict[str, int]
    # Option -> confidence-weighted vote total.
    weighted_counts: dict[str, float]
    total_votes: int
    quorum_met: bool
    # Minimum number of voters the session required.
    quorum_required: int
    # The session's quorum type, stored as its string value.
    quorum_type: str
    threshold_met: bool
    # Fraction of votes the winner needed.
    required_threshold: float
    # Fraction of votes the leading option actually received.
    actual_threshold: float
    # Human-readable explanation of the outcome.
    rationale: str
    tie_broken: bool = False
    # String value of the tie-breaking strategy applied, if any.
    tie_breaker_used: Optional[str] = None

    @property
    def has_winner(self) -> bool:
        """Report whether a winner exists and both quorum and threshold held."""
        if self.winner is None:
            return False
        return self.quorum_met and self.threshold_met


class VotingCoordinator:
    """
    Coordinates multi-agent voting sessions.

    Handles creating voting sessions, collecting votes, determining
    consensus, and recording results. Sessions and votes are read from and
    written to the injected ``OrchestratorDB``; the coordinator itself only
    holds the configured defaults.
    """

    def __init__(
        self,
        db: OrchestratorDB,
        default_quorum: int = 3,
        default_quorum_type: QuorumType = QuorumType.SIMPLE_MAJORITY,
        default_timeout_minutes: int = 30,
        default_tie_breaker: TieBreaker = TieBreaker.HIGHEST_CONFIDENCE,
    ) -> None:
        """
        Initialize the voting coordinator.

        Args:
            db: Database instance used for all session and vote storage
            default_quorum: Default minimum voters required
            default_quorum_type: Default quorum type
            default_timeout_minutes: Default voting session timeout
            default_tie_breaker: Default tie-breaking strategy
        """
        self.db = db
        self.default_quorum = default_quorum
        self.default_quorum_type = default_quorum_type
        self.default_timeout_minutes = default_timeout_minutes
        self.default_tie_breaker = default_tie_breaker

    def create_session(
        self,
        topic: str,
        created_by: str,
        options: list[str],
        description: Optional[str] = None,
        task_id: Optional[str] = None,
        required_voters: Optional[int] = None,
        quorum_type: Optional[QuorumType] = None,
        supermajority_threshold: float = 0.67,
        timeout_minutes: Optional[int] = None,
    ) -> VotingSession:
        """
        Create a new voting session and persist it.

        Args:
            topic: The subject of the vote
            created_by: Agent ID or 'orchestrator' who initiated the vote
            options: List of valid voting options (recorded in the
                description only; votes are not validated against it)
            description: Detailed description for voters
            task_id: Optional associated task
            required_voters: Minimum voters for valid result
                (defaults to ``self.default_quorum``)
            quorum_type: Type of quorum required
                (defaults to ``self.default_quorum_type``)
            supermajority_threshold: Threshold for supermajority (default 0.67)
            timeout_minutes: Session timeout
                (defaults to ``self.default_timeout_minutes``)

        Returns:
            The created VotingSession
        """
        session_id = self.db.generate_voting_session_id()

        # Apply defaults
        if required_voters is None:
            required_voters = self.default_quorum
        if quorum_type is None:
            quorum_type = self.default_quorum_type
        if timeout_minutes is None:
            timeout_minutes = self.default_timeout_minutes

        deadline = datetime.now() + timedelta(minutes=timeout_minutes)

        # Options have no dedicated field on the session, so append them to
        # the description as a comma-separated list for voters to read.
        full_description = description or ""
        if options:
            full_description += f"\n\nOptions: {', '.join(options)}"

        session = VotingSession(
            id=session_id,
            topic=topic,
            created_by=created_by,
            task_id=task_id,
            description=full_description,
            deadline_at=deadline,
            required_voters=required_voters,
            quorum_type=quorum_type.value,
            supermajority_threshold=supermajority_threshold,
        )

        self.db.create_voting_session(session)
        logger.info(
            f"Created voting session {session_id}: {topic} "
            f"(quorum={required_voters}, type={quorum_type.value})"
        )

        return session

    def cast_vote(
        self,
        session_id: str,
        agent_id: str,
        choice: str,
        confidence: float = 1.0,
        rationale: Optional[str] = None,
    ) -> bool:
        """
        Cast a vote in a voting session.

        The vote is rejected (returns False) if the session does not exist,
        is closed, or its deadline has already passed.

        Args:
            session_id: The voting session ID
            agent_id: The voting agent's ID
            choice: The option being voted for (not validated against the
                session's options)
            confidence: Confidence level, clamped to 0.0-1.0; NaN is
                treated as 0.0
            rationale: Explanation for the vote

        Returns:
            True if vote was cast successfully
        """
        import math

        session = self.db.get_voting_session(session_id)
        if not session:
            logger.error(f"Voting session {session_id} not found")
            return False

        if not session.is_open:
            logger.warning(f"Cannot vote in closed session {session_id}")
            return False

        # Expired sessions are only closed when check_expired_sessions runs,
        # so is_open may still be True after the deadline. Reject late votes
        # explicitly so they cannot change an already-expired outcome.
        if self._is_past_deadline(session):
            logger.warning(f"Cannot vote in expired session {session_id}")
            return False

        # NaN would survive the clamp below as 1.0 (every comparison with NaN
        # is False, so min() keeps its first argument), silently granting
        # full confidence. Treat it as no confidence instead.
        if math.isnan(confidence):
            confidence = 0.0
        confidence = max(0.0, min(1.0, confidence))

        vote = Vote(
            session_id=session_id,
            agent_id=agent_id,
            choice=choice,
            confidence=confidence,
            rationale=rationale,
        )

        self.db.cast_vote(vote)
        logger.info(
            f"Agent {agent_id} voted '{choice}' in session {session_id} "
            f"(confidence={confidence})"
        )

        return True

    def get_session_status(self, session_id: str) -> dict[str, Any]:
        """
        Get current status of a voting session.

        Returns dict with session info, votes cast, and current standings,
        or ``{"error": "Session not found"}`` if the session is unknown.
        """
        session = self.db.get_voting_session(session_id)
        if not session:
            return {"error": "Session not found"}

        votes = self.db.get_votes(session_id)
        vote_counts = self.db.get_vote_counts(session_id)
        weighted_counts = self.db.get_weighted_vote_counts(session_id)

        return {
            "session": session,
            "votes": votes,
            "vote_counts": vote_counts,
            "weighted_counts": weighted_counts,
            "total_votes": len(votes),
            "quorum_required": session.required_voters,
            "quorum_met": len(votes) >= session.required_voters,
            "is_open": session.is_open,
            "deadline": session.deadline_at,
        }

    def tally_votes(
        self,
        session_id: str,
        tie_breaker: Optional[TieBreaker] = None,
        close_session: bool = True,
    ) -> VotingResult:
        """
        Tally votes and determine the winner.

        Args:
            session_id: The voting session ID
            tie_breaker: Strategy for breaking ties
                (defaults to ``self.default_tie_breaker``)
            close_session: Whether to close the session after tallying

        Returns:
            VotingResult with winner and vote details. ``winner`` is only set
            when both the quorum and the threshold were met. An unknown
            session yields a result with no winner and rationale
            "Session not found".
        """
        if tie_breaker is None:
            tie_breaker = self.default_tie_breaker

        session = self.db.get_voting_session(session_id)
        if not session:
            return VotingResult(
                session_id=session_id,
                topic="Unknown",
                winner=None,
                vote_counts={},
                weighted_counts={},
                total_votes=0,
                quorum_met=False,
                quorum_required=0,
                quorum_type="unknown",
                threshold_met=False,
                required_threshold=0.0,
                actual_threshold=0.0,
                rationale="Session not found",
            )

        votes = self.db.get_votes(session_id)
        vote_counts = self.db.get_vote_counts(session_id)
        weighted_counts = self.db.get_weighted_vote_counts(session_id)
        total_votes = len(votes)

        # Check quorum
        quorum_met = total_votes >= session.required_voters

        # Determine threshold based on quorum type
        quorum_type = session.quorum_type
        required_threshold = self._get_required_threshold(
            quorum_type, session.supermajority_threshold
        )

        # Find the winner and check threshold
        winner, actual_threshold, tie_broken, tie_breaker_used = self._determine_winner(
            votes, vote_counts, weighted_counts, required_threshold, tie_breaker
        )

        # Inclusive comparison: reaching the threshold exactly counts as met.
        threshold_met = actual_threshold >= required_threshold if winner else False

        # Generate rationale
        rationale = self._generate_rationale(
            session,
            winner,
            vote_counts,
            total_votes,
            quorum_met,
            threshold_met,
            tie_broken,
        )

        result = VotingResult(
            session_id=session_id,
            topic=session.topic,
            winner=winner if quorum_met and threshold_met else None,
            vote_counts=vote_counts,
            weighted_counts=weighted_counts,
            total_votes=total_votes,
            quorum_met=quorum_met,
            quorum_required=session.required_voters,
            quorum_type=quorum_type,
            threshold_met=threshold_met,
            required_threshold=required_threshold,
            actual_threshold=actual_threshold,
            rationale=rationale,
            tie_broken=tie_broken,
            tie_breaker_used=tie_breaker_used.value if tie_breaker_used else None,
        )

        # Close session if requested
        if close_session and session.is_open:
            self.db.close_voting_session(
                session_id,
                result.winner,
                result.rationale,
            )
            logger.info(f"Closed voting session {session_id}: winner={result.winner}")

        return result

    def check_expired_sessions(self) -> list[VotingResult]:
        """
        Check for and tally expired voting sessions.

        Every open session whose deadline is at or before the current time
        is tallied and closed.

        Returns list of results for newly closed sessions.
        """
        results = []
        open_sessions = self.db.get_open_voting_sessions()

        # Use a single timestamp so every session is judged against the same
        # instant.
        now = datetime.now()
        for session in open_sessions:
            if self._is_past_deadline(session, now):
                logger.info(f"Voting session {session.id} has expired, tallying votes")
                result = self.tally_votes(session.id, close_session=True)
                results.append(result)

        return results

    @staticmethod
    def _is_past_deadline(
        session: VotingSession, now: Optional[datetime] = None
    ) -> bool:
        """Return True if the session has a deadline at or before *now*.

        Sessions without a deadline never expire. *now* defaults to the
        current local time, matching how deadlines are computed in
        create_session.
        """
        if now is None:
            now = datetime.now()
        return bool(session.deadline_at) and session.deadline_at <= now

    def _get_required_threshold(
        self, quorum_type: str, supermajority_threshold: float
    ) -> float:
        """Get the required vote-share threshold for a quorum type.

        Unknown quorum types fall back to the simple-majority threshold.
        """
        if quorum_type == QuorumType.SIMPLE_MAJORITY.value:
            # tally_votes compares with >=, so exactly half of the votes
            # qualifies; this lets a tie-broken 50/50 split produce a winner.
            return 0.5
        elif quorum_type == QuorumType.SUPERMAJORITY.value:
            return supermajority_threshold
        elif quorum_type == QuorumType.UNANIMOUS.value:
            return 1.0  # 100%
        else:
            return 0.5  # Default to simple majority

    def _determine_winner(
        self,
        votes: list[Vote],
        vote_counts: dict[str, int],
        weighted_counts: dict[str, float],
        required_threshold: float,
        tie_breaker: TieBreaker,
    ) -> tuple[Optional[str], float, bool, Optional[TieBreaker]]:
        """
        Determine the winner from vote counts.

        ``required_threshold`` is currently unused; the threshold comparison
        is performed by the caller.

        Returns (winner, actual_threshold, tie_broken, tie_breaker_used):
            winner: the option with the most votes (after tie-breaking), or
                None if there are no votes or a tie could not be resolved
            actual_threshold: that option's share of all counted votes (on
                an unresolved tie, the top count's share)
            tie_broken: True whenever the top count was shared, even if no
                winner could be chosen
            tie_breaker_used: the strategy applied, NO_DECISION when a tie
                was left unresolved, or None when there was no tie
        """
        if not vote_counts:
            return None, 0.0, False, None

        total_votes = sum(vote_counts.values())
        if total_votes == 0:
            return None, 0.0, False, None

        # Find top options
        max_votes = max(vote_counts.values())
        top_options = [opt for opt, count in vote_counts.items() if count == max_votes]

        # A single leader needs no tie-breaking.
        if len(top_options) == 1:
            return top_options[0], max_votes / total_votes, False, None

        winner, used_breaker = self._break_tie(
            top_options, votes, weighted_counts, tie_breaker
        )
        if winner:
            return winner, vote_counts[winner] / total_votes, True, used_breaker

        # The tie could not be resolved: no winner.
        return None, max_votes / total_votes, True, TieBreaker.NO_DECISION

    def _break_tie(
        self,
        tied_options: list[str],
        votes: list[Vote],
        weighted_counts: dict[str, float],
        tie_breaker: TieBreaker,
    ) -> tuple[Optional[str], Optional[TieBreaker]]:
        """
        Break a tie between options.

        FIRST_VOTE follows the order of ``votes`` as returned by the
        database (presumably chronological; confirm against
        ``OrchestratorDB.get_votes``). HIGHEST_CONFIDENCE picks the tied
        option with the largest weighted total; if those are also equal,
        ``max`` keeps the first one in ``tied_options``.

        Returns (winner, tie_breaker_used). winner is None when the
        strategy yields no decision.
        """
        if tie_breaker == TieBreaker.FIRST_VOTE:
            # Find which option was voted for first
            for vote in votes:
                if vote.choice in tied_options:
                    return vote.choice, TieBreaker.FIRST_VOTE
            return None, None

        elif tie_breaker == TieBreaker.HIGHEST_CONFIDENCE:
            # Choose option with highest total confidence
            best_option = max(
                tied_options,
                key=lambda opt: weighted_counts.get(opt, 0),
            )
            return best_option, TieBreaker.HIGHEST_CONFIDENCE

        elif tie_breaker == TieBreaker.RANDOM:
            import random
            return random.choice(tied_options), TieBreaker.RANDOM

        elif tie_breaker == TieBreaker.NO_DECISION:
            return None, TieBreaker.NO_DECISION

        return None, None

    def _generate_rationale(
        self,
        session: VotingSession,
        winner: Optional[str],
        vote_counts: dict[str, int],
        total_votes: int,
        quorum_met: bool,
        threshold_met: bool,
        tie_broken: bool,
    ) -> str:
        """Generate a human-readable rationale for the voting result.

        The sentences cover the vote summary, quorum status, threshold
        status, tie handling and the final decision, joined with ". ".
        """
        parts = []

        # Vote summary
        if vote_counts:
            vote_summary = ", ".join(
                f"{opt}: {count}" for opt, count in vote_counts.items()
            )
            parts.append(f"Votes received: {vote_summary} (total: {total_votes})")
        else:
            parts.append("No votes received")

        # Quorum status
        if quorum_met:
            parts.append(
                f"Quorum met ({total_votes}/{session.required_voters} required)"
            )
        else:
            parts.append(
                f"Quorum NOT met ({total_votes}/{session.required_voters} required)"
            )

        # Threshold status
        if winner and threshold_met:
            parts.append(f"Winner: '{winner}' met {session.quorum_type} threshold")
        elif winner and not threshold_met:
            parts.append(f"'{winner}' did not meet {session.quorum_type} threshold")

        # Tie information. tie_broken is also set when the tie was left
        # unresolved, so only claim it was broken if a winner emerged.
        if tie_broken:
            if winner:
                parts.append("Tie was broken using configured strategy")
            else:
                parts.append("Tie could not be broken; no winner selected")

        # Final outcome
        if winner and quorum_met and threshold_met:
            parts.append(f"Decision: {winner}")
        else:
            parts.append("No decision reached")

        return ". ".join(parts)

    def get_pending_votes_for_agent(
        self, agent_id: str
    ) -> list[VotingSession]:
        """
        Get open voting sessions where an agent hasn't voted.

        Sessions whose deadline has already passed are excluded, since
        cast_vote rejects votes for them.

        Returns list of sessions awaiting this agent's vote.
        """
        now = datetime.now()
        return [
            session
            for session in self.db.get_open_voting_sessions()
            if not self._is_past_deadline(session, now)
            and not self.db.has_agent_voted(session.id, agent_id)
        ]
