"""
Configuration management for Agent Orchestrator.

Loads configuration from environment variables (a .env file, if present, is
loaded into the environment first) with sensible defaults.
All sensitive values (API keys) must be provided via environment.
"""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

from dotenv import load_dotenv


class ConfigValidationError(Exception):
    """Signals that one or more configuration checks failed.

    Attributes:
        errors: The individual failure messages, in the order they were given.
    """

    def __init__(self, errors: list[str]):
        self.errors = errors
        count = len(errors)
        bullets = "\n".join(f"  - {err}" for err in errors)
        super().__init__(
            f"Configuration validation failed with {count} error(s):\n{bullets}"
        )

# Load variables from a .env file (when one can be found) into os.environ at
# import time, so every _get_env* lookup below can see them. Unlike
# reload_config(), this call does not pass override=True.
load_dotenv()


def _get_env(key: str, default: Optional[str] = None, required: bool = False) -> Optional[str]:
    """Look up ``key`` in the process environment.

    Args:
        key: Environment variable name.
        default: Value returned when ``key`` is not present in the environment.
        required: When true, failing to resolve any value (variable unset and
            ``default`` is None) is an error.

    Returns:
        The environment value if set, otherwise ``default`` (which may be None).

    Raises:
        ValueError: If ``required`` is true and no value could be resolved.
    """
    resolved = os.environ.get(key, default)
    if resolved is not None or not required:
        return resolved
    raise ValueError(f"Required environment variable {key} is not set")


def _get_env_int(key: str, default: int) -> int:
    """Get environment variable as integer.

    Unset, empty, or whitespace-only values fall back to ``default``.
    Surrounding whitespace around a number is accepted.

    Raises:
        ValueError: If the variable is set to something ``int()`` cannot
            parse. The message names the offending variable so a bad
            ``.env`` entry is easy to locate.
    """
    value = _get_env(key)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {key} must be an integer (got {value!r})"
        ) from err


def _get_env_float(key: str, default: float) -> float:
    """Get environment variable as float.

    Unset, empty, or whitespace-only values fall back to ``default``.
    Anything ``float()`` accepts is returned unchanged, including ``"inf"``
    and ``"nan"``. Range checks are left to the config sections' ``validate``
    methods.

    Raises:
        ValueError: If the variable is set to something ``float()`` cannot
            parse. The message names the offending variable.
    """
    value = _get_env(key)
    if value is None or not value.strip():
        return default
    try:
        return float(value)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {key} must be a number (got {value!r})"
        ) from err


def _get_env_bool(key: str, default: bool) -> bool:
    """Get environment variable as boolean.

    Unset, empty, or whitespace-only values fall back to ``default``. This
    matches the int/float helpers and keeps e.g. ``REDACT_SECRETS=`` from
    silently turning a True default off. Otherwise the value is trimmed and
    compared case-insensitively: ``true``/``1``/``yes``/``on`` give True, and
    any other text gives False.
    """
    value = _get_env(key)
    if value is None:
        return default
    normalized = value.strip().lower()
    if not normalized:
        return default
    return normalized in ("true", "1", "yes", "on")


def _get_env_list(key: str, default: list[str]) -> list[str]:
    """Read ``key`` as a comma-separated list of trimmed, non-blank items.

    Returns ``default`` itself (not a copy) only when the variable is unset.
    A set-but-empty value yields ``[]``.
    """
    raw = _get_env(key)
    if raw is None:
        return default
    trimmed = (piece.strip() for piece in raw.split(","))
    return [piece for piece in trimmed if piece]


@dataclass
class APIKeysConfig:
    """API key configuration - loaded from environment only.

    The keys are excluded from the dataclass ``repr`` so that printing or
    logging this object does not leak secrets. This includes any dataclass
    that embeds it, such as ``Config``. A key counts as configured only when
    it is a non-blank string. An empty assignment such as
    ``export ANTHROPIC_API_KEY=`` is treated as not configured.
    """

    anthropic_api_key: Optional[str] = field(
        default_factory=lambda: _get_env("ANTHROPIC_API_KEY"), repr=False
    )
    openai_api_key: Optional[str] = field(
        default_factory=lambda: _get_env("OPENAI_API_KEY"), repr=False
    )
    google_api_key: Optional[str] = field(
        default_factory=lambda: _get_env("GOOGLE_API_KEY"), repr=False
    )

    @staticmethod
    def _is_set(value: Optional[str]) -> bool:
        """Return True when ``value`` is a non-empty, non-whitespace string."""
        return bool(value and value.strip())

    def has_anthropic(self) -> bool:
        """Check if a non-blank Anthropic API key is configured."""
        return self._is_set(self.anthropic_api_key)

    def has_openai(self) -> bool:
        """Check if a non-blank OpenAI API key is configured."""
        return self._is_set(self.openai_api_key)

    def has_google(self) -> bool:
        """Check if a non-blank Google API key is configured."""
        return self._is_set(self.google_api_key)


@dataclass
class DatabaseConfig:
    """Location of the orchestrator's database file.

    ``path`` comes from ``DATABASE_PATH`` (default ``data/orchestrator.db``).
    Relative paths are kept as given and are not resolved here.
    """

    path: Path = field(
        default_factory=lambda: Path(
            _get_env(key="DATABASE_PATH", default="data/orchestrator.db")
        )
    )

    def ensure_directory(self) -> None:
        """Create the directory that will hold the database file, if missing.

        Missing intermediate directories are created too. A directory that
        already exists is not an error.
        """
        parent_dir = self.path.parent
        parent_dir.mkdir(exist_ok=True, parents=True)


@dataclass
class AgentBudgetConfig:
    """Budget configuration for a single agent.

    Holds one agent's daily token and cost limits plus ``rate_limit_rpm``,
    which defaults to 60 when not given.
    """

    daily_token_limit: int
    daily_cost_limit: float
    rate_limit_rpm: int = 60

    def validate(self, agent_name: str = "agent") -> list[str]:
        """Validate agent budget configuration.

        Args:
            agent_name: Prefix used in error messages (e.g. ``budgets.codex``).

        Returns:
            Error messages, empty when every limit is in range.
        """
        errors = []
        if self.daily_token_limit < 0:
            errors.append(f"{agent_name}.daily_token_limit must be non-negative (got {self.daily_token_limit})")
        # `not x >= 0` rather than `x < 0` so a NaN cost limit (float() happily
        # parses "nan") is rejected: NaN compares False against every number.
        if not self.daily_cost_limit >= 0:
            errors.append(f"{agent_name}.daily_cost_limit must be non-negative (got {self.daily_cost_limit})")
        if self.rate_limit_rpm < 1:
            errors.append(f"{agent_name}.rate_limit_rpm must be at least 1 (got {self.rate_limit_rpm})")
        return errors


@dataclass
class BudgetsConfig:
    """Budget configuration for all agents.

    Combines global USD limits (``DAILY_BUDGET_USD``, ``TASK_BUDGET_USD``,
    ``WARN_THRESHOLD_PERCENT``) with per-agent token/cost limits. Every value
    is read from the environment when not passed explicitly.
    """

    # Global budget limits (for simpler configuration)
    daily_budget_usd: float = field(
        default_factory=lambda: _get_env_float("DAILY_BUDGET_USD", 50.0)
    )
    task_budget_usd: float = field(
        default_factory=lambda: _get_env_float("TASK_BUDGET_USD", 5.0)
    )
    warn_threshold_percent: int = field(
        default_factory=lambda: _get_env_int("WARN_THRESHOLD_PERCENT", 80)
    )

    # Per-agent budgets (more granular control)
    claude_code: AgentBudgetConfig = field(
        default_factory=lambda: AgentBudgetConfig(
            daily_token_limit=_get_env_int("CLAUDE_CODE_DAILY_TOKEN_LIMIT", 500_000),
            daily_cost_limit=_get_env_float("CLAUDE_CODE_DAILY_COST_LIMIT", 50.0),
        )
    )
    gemini_cli: AgentBudgetConfig = field(
        default_factory=lambda: AgentBudgetConfig(
            daily_token_limit=_get_env_int("GEMINI_CLI_DAILY_TOKEN_LIMIT", 1_000_000),
            daily_cost_limit=_get_env_float("GEMINI_CLI_DAILY_COST_LIMIT", 0.0),
        )
    )
    codex: AgentBudgetConfig = field(
        default_factory=lambda: AgentBudgetConfig(
            daily_token_limit=_get_env_int("CODEX_DAILY_TOKEN_LIMIT", 200_000),
            daily_cost_limit=_get_env_float("CODEX_DAILY_COST_LIMIT", 20.0),
        )
    )

    def get_budget(self, agent_name: str) -> Optional[AgentBudgetConfig]:
        """Get budget config for an agent by name.

        Returns:
            The matching per-agent budget, or None for an unknown ``agent_name``.
        """
        budgets = {
            "claude_code": self.claude_code,
            "gemini_cli": self.gemini_cli,
            "codex": self.codex,
        }
        return budgets.get(agent_name)

    def validate(self) -> list[str]:
        """Validate global and per-agent budget settings.

        Returns:
            Error messages, empty when valid. Per-agent errors are prefixed
            with ``budgets.<agent>``.
        """
        errors = []

        # `not x > 0` instead of `x <= 0` so NaN (e.g. DAILY_BUDGET_USD=nan) is
        # rejected too. NaN compares False against every number, so it would
        # otherwise pass as a "positive" budget.
        if not self.daily_budget_usd > 0:
            errors.append(f"budgets.daily_budget_usd must be positive (got {self.daily_budget_usd})")

        if not self.task_budget_usd > 0:
            errors.append(f"budgets.task_budget_usd must be positive (got {self.task_budget_usd})")
        elif self.task_budget_usd > self.daily_budget_usd:
            errors.append(
                f"budgets.task_budget_usd ({self.task_budget_usd}) cannot exceed "
                f"daily_budget_usd ({self.daily_budget_usd})"
            )

        if not 0 <= self.warn_threshold_percent <= 100:
            errors.append(
                f"budgets.warn_threshold_percent must be between 0 and 100 "
                f"(got {self.warn_threshold_percent})"
            )

        # Validate per-agent configs
        errors.extend(self.claude_code.validate("budgets.claude_code"))
        errors.extend(self.gemini_cli.validate("budgets.gemini_cli"))
        errors.extend(self.codex.validate("budgets.codex"))

        return errors


@dataclass
class ControlLoopConfig:
    """Tuning knobs for the orchestrator's control loop.

    Each field is read from the environment variable named in its default
    factory when not passed explicitly.
    """

    health_check_interval: int = field(
        default_factory=lambda: _get_env_int(key="HEALTH_CHECK_INTERVAL", default=60)
    )
    max_auto_prompt_attempts: int = field(
        default_factory=lambda: _get_env_int(key="MAX_AUTO_PROMPT_ATTEMPTS", default=2)
    )
    stuck_idle_threshold_minutes: int = field(
        default_factory=lambda: _get_env_int(key="STUCK_IDLE_THRESHOLD_MINUTES", default=10)
    )
    stuck_error_threshold_count: int = field(
        default_factory=lambda: _get_env_int(key="STUCK_ERROR_THRESHOLD_COUNT", default=3)
    )
    auto_approve_low_risk: bool = field(
        default_factory=lambda: _get_env_bool(key="AUTO_APPROVE_LOW_RISK", default=True)
    )

    def validate(self) -> list[str]:
        """Check the numeric bounds and report every violated rule.

        Returns:
            One message per failing rule, in field-declaration order. The list
            is empty when all bounds hold.
        """
        interval = self.health_check_interval
        attempts = self.max_auto_prompt_attempts
        idle_minutes = self.stuck_idle_threshold_minutes
        error_count = self.stuck_error_threshold_count

        # (violated?, message) pairs, built in order and filtered at the end.
        rules = [
            (
                interval < 1,
                f"control_loop.health_check_interval must be at least 1 second "
                f"(got {interval})",
            ),
            (
                attempts < 0,
                f"control_loop.max_auto_prompt_attempts must be non-negative "
                f"(got {attempts})",
            ),
            (
                idle_minutes < 1,
                f"control_loop.stuck_idle_threshold_minutes must be at least 1 "
                f"(got {idle_minutes})",
            ),
            (
                error_count < 1,
                f"control_loop.stuck_error_threshold_count must be at least 1 "
                f"(got {error_count})",
            ),
        ]
        return [message for violated, message in rules if violated]


VALID_RISK_LEVELS = {"low", "medium", "high", "critical"}


@dataclass
class RiskGateConfig:
    """Risk gate configuration.

    ``default_risk_level`` is read from ``DEFAULT_RISK_LEVEL`` (default
    ``"medium"``) and validated case-insensitively against
    ``VALID_RISK_LEVELS``. The stored value keeps its original casing.
    NOTE(review): consumers comparing it to lowercase literals may need to
    normalize; confirm.
    ``approval_timeout_seconds`` comes from ``APPROVAL_TIMEOUT_SECONDS``
    (default 3600).
    """

    default_risk_level: str = field(
        default_factory=lambda: _get_env("DEFAULT_RISK_LEVEL", "medium")
    )
    approval_timeout_seconds: int = field(
        default_factory=lambda: _get_env_int("APPROVAL_TIMEOUT_SECONDS", 3600)
    )

    def validate(self) -> list[str]:
        """Validate risk gate configuration.

        Returns:
            Error messages, empty when valid.
        """
        errors = []

        if self.default_risk_level.lower() not in VALID_RISK_LEVELS:
            # Sort the allowed values so the message is identical across runs.
            # A set of strings iterates in an order that depends on hash
            # randomization.
            errors.append(
                f"risk_gate.default_risk_level must be one of "
                f"{sorted(VALID_RISK_LEVELS)} (got '{self.default_risk_level}')"
            )

        if self.approval_timeout_seconds < 1:
            errors.append(
                f"risk_gate.approval_timeout_seconds must be at least 1 "
                f"(got {self.approval_timeout_seconds})"
            )

        return errors


@dataclass
class WebhookConfig:
    """Webhook configuration for async interrupt handler.

    ``url`` and ``secret`` are excluded from the dataclass ``repr`` so that
    logging a config does not leak them. Webhook URLs (e.g. Slack incoming
    webhooks) often embed the credential itself.
    """

    url: Optional[str] = field(
        default_factory=lambda: _get_env("WEBHOOK_URL"), repr=False
    )
    secret: Optional[str] = field(
        default_factory=lambda: _get_env("WEBHOOK_SECRET"), repr=False
    )
    channel: Optional[str] = field(default_factory=lambda: _get_env("SLACK_CHANNEL"))

    def is_configured(self) -> bool:
        """Check if webhook is properly configured (a non-blank URL is set)."""
        return bool(self.url and self.url.strip())


@dataclass
class WorkspaceConfig:
    """Settings for where worktrees are created and which branches are protected.

    Values come from ``WORKTREE_BASE_PATH`` (default ``"../"``) and the
    comma-separated ``PROTECTED_BRANCHES`` (default ``main``, ``master``,
    ``prod``) unless passed explicitly.
    """

    # Kept exactly as given; relative paths are not resolved here.
    worktree_base_path: Path = field(
        default_factory=lambda: Path(_get_env(key="WORKTREE_BASE_PATH", default="../"))
    )
    # The factory builds a fresh default list for each instance.
    protected_branches: list[str] = field(
        default_factory=lambda: _get_env_list(
            key="PROTECTED_BRANCHES", default=["main", "master", "prod"]
        )
    )


@dataclass
class CLIAuthConfig:
    """Settings for the CLI authentication check.

    Sourced from ``CLI_AUTH_CHECK_ENABLED`` (default True),
    ``CLI_AUTH_CHECK_PROMPT`` and ``CLI_AUTH_CHECK_TIMEOUT_SECONDS``
    (default 30) unless passed explicitly.
    """

    # Switch for the check as a whole.
    check_enabled: bool = field(
        default_factory=lambda: _get_env_bool(key="CLI_AUTH_CHECK_ENABLED", default=True)
    )
    # Presumably the prompt sent to each CLI as the auth probe; confirm with caller.
    auth_check_prompt: str = field(
        default_factory=lambda: _get_env(
            key="CLI_AUTH_CHECK_PROMPT", default="Auth check: respond with OK."
        )
    )
    # NOTE(review): no range check exists for this value (this class has no validate()).
    auth_check_timeout_seconds: int = field(
        default_factory=lambda: _get_env_int(key="CLI_AUTH_CHECK_TIMEOUT_SECONDS", default=30)
    )


@dataclass
class RoutingConfig:
    """Thresholds that control CLI rebalancing.

    ``cli_soft_limit_pct`` and ``cli_hard_limit_pct`` are fractions that must
    lie in [0.0, 1.0], with the soft limit strictly below the hard limit.
    They are read from ``CLI_SOFT_LIMIT_PCT`` / ``CLI_HARD_LIMIT_PCT``
    (defaults 0.85 / 0.95).
    """

    cli_rebalance_enabled: bool = field(
        default_factory=lambda: _get_env_bool(key="CLI_REBALANCE_ENABLED", default=True)
    )
    cli_soft_limit_pct: float = field(
        default_factory=lambda: _get_env_float(key="CLI_SOFT_LIMIT_PCT", default=0.85)
    )
    cli_hard_limit_pct: float = field(
        default_factory=lambda: _get_env_float(key="CLI_HARD_LIMIT_PCT", default=0.95)
    )

    def validate(self) -> list[str]:
        """Check the limit fractions and their ordering.

        Returns:
            One message per violated rule, in the order the rules are listed
            below. The list is empty when the configuration is valid.
        """
        soft = self.cli_soft_limit_pct
        hard = self.cli_hard_limit_pct

        # (violated?, message) pairs. The chained comparisons are negated so
        # that NaN (which fails every comparison) counts as out of range.
        rules = [
            (
                not 0.0 <= soft <= 1.0,
                f"routing.cli_soft_limit_pct must be between 0.0 and 1.0 "
                f"(got {soft})",
            ),
            (
                not 0.0 <= hard <= 1.0,
                f"routing.cli_hard_limit_pct must be between 0.0 and 1.0 "
                f"(got {hard})",
            ),
            (
                soft >= hard,
                f"routing.cli_soft_limit_pct ({soft}) must be less than "
                f"cli_hard_limit_pct ({hard})",
            ),
        ]
        return [message for violated, message in rules if violated]


VALID_LOG_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}


@dataclass
class LoggingConfig:
    """Logging configuration.

    ``level`` comes from ``LOG_LEVEL`` (default ``"INFO"``) and is validated
    case-insensitively against ``VALID_LOG_LEVELS``. The stored value keeps
    its original casing. ``file`` comes from ``LOG_FILE`` (None when unset).
    ``redact_secrets`` comes from ``REDACT_SECRETS`` (default True).
    """

    level: str = field(default_factory=lambda: _get_env("LOG_LEVEL", "INFO"))
    file: Optional[str] = field(default_factory=lambda: _get_env("LOG_FILE"))
    redact_secrets: bool = field(
        default_factory=lambda: _get_env_bool("REDACT_SECRETS", True)
    )

    def validate(self) -> list[str]:
        """Validate logging configuration.

        Returns:
            Error messages, empty when valid.
        """
        errors = []

        if self.level.upper() not in VALID_LOG_LEVELS:
            # Sorted so the message does not depend on set iteration order,
            # which varies between processes under hash randomization.
            errors.append(
                f"logging.level must be one of {sorted(VALID_LOG_LEVELS)} "
                f"(got '{self.level}')"
            )

        return errors


@dataclass
class ObservabilityConfig:
    """Optional endpoints for external observability tooling.

    Both fields are ``None`` unless the matching environment variable
    (``AI_OBSERVER_URL`` / ``TOKEN_AUDIT_MCP_URL``) is set.
    """

    # Read from AI_OBSERVER_URL; no default.
    ai_observer_url: Optional[str] = field(
        default_factory=lambda: _get_env(key="AI_OBSERVER_URL")
    )
    # Read from TOKEN_AUDIT_MCP_URL; no default.
    token_audit_mcp_url: Optional[str] = field(
        default_factory=lambda: _get_env(key="TOKEN_AUDIT_MCP_URL")
    )


@dataclass
class Config:
    """
    Main configuration class for Agent Orchestrator.

    Each section is built from environment variables when not passed
    explicitly. Constructing a Config also creates the database file's parent
    directory (see ``__post_init__``).

    Usage:
        config = Config()
        if config.api_keys.has_anthropic():
            # Use Claude

    Validation:
        config = Config()
        config.validate()  # Raises ConfigValidationError if invalid

        # Or check without raising:
        errors = config.validate(raise_on_error=False)
        if errors:
            print(f"Found {len(errors)} config errors")
    """

    # Attribute names of the sections whose ``validate()`` is run, in the
    # order their errors are reported. This is the single source of truth for
    # both ``validate`` and ``get_validation_summary``. It is deliberately
    # left un-annotated so @dataclass treats it as a plain class attribute,
    # not a field.
    _VALIDATED_SECTIONS = ("budgets", "control_loop", "risk_gate", "routing", "logging")

    api_keys: APIKeysConfig = field(default_factory=APIKeysConfig)
    database: DatabaseConfig = field(default_factory=DatabaseConfig)
    budgets: BudgetsConfig = field(default_factory=BudgetsConfig)
    control_loop: ControlLoopConfig = field(default_factory=ControlLoopConfig)
    risk_gate: RiskGateConfig = field(default_factory=RiskGateConfig)
    webhook: WebhookConfig = field(default_factory=WebhookConfig)
    workspace: WorkspaceConfig = field(default_factory=WorkspaceConfig)
    cli_auth: CLIAuthConfig = field(default_factory=CLIAuthConfig)
    routing: RoutingConfig = field(default_factory=RoutingConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    observability: ObservabilityConfig = field(default_factory=ObservabilityConfig)

    def __post_init__(self) -> None:
        """Post-initialization setup: ensure the database directory exists."""
        self.database.ensure_directory()

    def validate(self, raise_on_error: bool = True) -> list[str]:
        """
        Validate all configuration settings.

        Args:
            raise_on_error: If True (default), raise ConfigValidationError on failure.
                           If False, just return the list of errors.

        Returns:
            List of validation error messages (empty if valid).

        Raises:
            ConfigValidationError: If validation fails and raise_on_error is True.
        """
        errors: list[str] = []

        # Validate each config section that has validation
        for section_name in self._VALIDATED_SECTIONS:
            errors.extend(getattr(self, section_name).validate())

        if errors and raise_on_error:
            raise ConfigValidationError(errors)

        return errors

    def is_valid(self) -> bool:
        """Check if configuration is valid without raising."""
        return not self.validate(raise_on_error=False)

    def get_validation_summary(self) -> dict[str, Any]:
        """
        Get a summary of configuration validation status.

        Returns:
            Dictionary with keys ``valid`` (bool), ``error_count`` (int),
            ``errors`` (list of messages) and ``sections_validated`` (list of
            section attribute names, in validation order).
        """
        errors = self.validate(raise_on_error=False)
        return {
            "valid": not errors,
            "error_count": len(errors),
            "errors": errors,
            "sections_validated": list(self._VALIDATED_SECTIONS),
        }


# Process-wide Config instance. It is created lazily by get_config() on first
# use and replaced by reload_config(). NOTE(review): no lock guards the lazy
# creation, so concurrent first calls could each construct a Config; confirm
# callers are single-threaded at startup.
_config: Optional[Config] = None


def get_config(validate: bool = False) -> Config:
    """
    Get the global configuration instance, creating it on first use.

    Args:
        validate: If True, validate the config and raise on errors. This
            applies on every call, not only the one that creates the instance.

    Returns:
        The global Config instance.

    Raises:
        ConfigValidationError: If validate=True and config is invalid.
    """
    global _config
    if _config is None:
        _config = Config()
    # Validation is decoupled from creation. Otherwise a singleton first
    # created without validation, or one that already failed validation,
    # would be returned to a validate=True caller as if it had passed.
    if validate:
        _config.validate()
    return _config


def reload_config(validate: bool = False) -> Config:
    """
    Reload configuration from environment.

    The .env file is re-read with override=True, so its values replace
    variables already present in os.environ. This happens before validation,
    so os.environ is updated even if validation subsequently fails.

    Args:
        validate: If True, validate config and raise on errors. On failure
            the previously loaded global config is left in place.

    Returns:
        The reloaded Config instance.

    Raises:
        ConfigValidationError: If validate=True and config is invalid.
    """
    global _config
    load_dotenv(override=True)
    new_config = Config()
    if validate:
        # Validate before publishing so a failed reload cannot replace the
        # current global config with an invalid one.
        new_config.validate()
    _config = new_config
    return _config


def validate_config() -> list[str]:
    """
    Report problems with the global configuration without raising.

    The global instance comes from get_config(), which creates it if it does
    not exist yet.

    Returns:
        List of validation errors (empty if valid).
    """
    current = get_config()
    problems = current.validate(raise_on_error=False)
    return problems
