"""
Model Restriction Service

This module provides centralized management of model usage restrictions
based on environment variables. It allows organizations to limit which
models can be used from each provider for cost control, compliance, or
standardization purposes.

Environment Variables:
- OPENAI_ALLOWED_MODELS: Comma-separated list of allowed OpenAI models
- GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
- XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
- OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
- DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models
- BLOCKED_MODELS: Comma-separated list of models to block globally (overrides allows)
- DISABLED_MODEL_PATTERNS: Comma-separated substrings; any model whose name
  contains one of them is blocked (e.g., "claude,anthropic")

Example:
    OPENAI_ALLOWED_MODELS=o3-mini,o4-mini
    GOOGLE_ALLOWED_MODELS=flash
    XAI_ALLOWED_MODELS=grok-3,grok-3-fast
    OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
    
    # Block specific models
    BLOCKED_MODELS=gpt-4,claude-opus
    
    # Block all models matching patterns
    DISABLED_MODEL_PATTERNS=claude,anthropic,gpt-3
"""

import logging
import os
from typing import Any, Optional

from providers.base import ProviderType

# Module-level logger; the restriction service reports loaded settings,
# validation warnings and blocked-model decisions through it.
logger = logging.getLogger(__name__)


class ModelRestrictionService:
    """
    Centralized service for managing model usage restrictions.

    Three independent layers are read from the environment when an instance
    is constructed:

    1. Per-provider allow-lists (the ``*_ALLOWED_MODELS`` variables listed in
       ``ENV_VARS``).
    2. A global block-list of exact model names (``BLOCKED_MODELS``).
    3. A global list of substrings (``DISABLED_MODEL_PATTERNS``); a model
       whose name contains any of them is blocked.

    Every configured name/pattern is lower-cased on load and every lookup
    lower-cases its argument, so matching is case-insensitive.

    The allow-lists can additionally be checked against the models a provider
    actually knows via ``validate_against_known_models``.
    """

    # Environment variable names
    ENV_VARS = {
        ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",
        ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
        ProviderType.XAI: "XAI_ALLOWED_MODELS",
        ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
        ProviderType.DIAL: "DIAL_ALLOWED_MODELS",
    }

    def __init__(self):
        """Initialize the restriction service by loading from environment."""
        # Provider -> allow-list. A provider without an entry is unrestricted.
        self.restrictions: dict[ProviderType, set[str]] = {}
        # Exact (lower-cased) model names that are blocked for every provider.
        self.blocked_models: set[str] = set()
        # Lower-cased substrings; kept as a list so configuration order is preserved.
        self.disabled_patterns: list[str] = []
        self._load_from_env()
        self._load_blocked_models()
        self._load_disabled_patterns()

    @staticmethod
    def _parse_csv(raw_value: Optional[str]) -> list[str]:
        """
        Split a comma-separated setting into its normalized entries.

        Args:
            raw_value: Raw environment value, or None when the variable is unset

        Returns:
            Entries stripped of surrounding whitespace and lower-cased, in their
            original order. Empty entries (e.g. from ``"a,,b"`` or ``" , "``)
            are dropped; duplicates are kept.
        """
        if not raw_value:
            return []
        normalized = (entry.strip().lower() for entry in raw_value.split(","))
        return [entry for entry in normalized if entry]

    def _load_from_env(self) -> None:
        """Load the per-provider allow-lists from the variables in ``ENV_VARS``."""
        for provider_type, env_var in self.ENV_VARS.items():
            env_value = os.getenv(env_var)

            if not env_value:
                # Not set or empty - no restrictions (allow all models)
                logger.debug(f"{env_var} not set or empty - all {provider_type.value} models allowed")
                continue

            models = set(self._parse_csv(env_value))

            if models:
                self.restrictions[provider_type] = models
                logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
            else:
                # All entries were empty after cleaning - treat as no restrictions
                logger.debug(f"{env_var} contains only whitespace - all {provider_type.value} models allowed")

    def _load_blocked_models(self) -> None:
        """Load globally blocked models from BLOCKED_MODELS environment variable."""
        env_value = os.getenv("BLOCKED_MODELS")

        if not env_value:
            logger.debug("BLOCKED_MODELS not set - no models explicitly blocked")
            return

        self.blocked_models.update(self._parse_csv(env_value))

        if self.blocked_models:
            logger.info(f"Globally blocked models: {sorted(self.blocked_models)}")

    def _load_disabled_patterns(self) -> None:
        """Load disabled model patterns from DISABLED_MODEL_PATTERNS environment variable."""
        env_value = os.getenv("DISABLED_MODEL_PATTERNS")

        if not env_value:
            logger.debug("DISABLED_MODEL_PATTERNS not set - no patterns disabled")
            return

        self.disabled_patterns.extend(self._parse_csv(env_value))

        if self.disabled_patterns:
            logger.info(f"Disabled model patterns: {self.disabled_patterns}")

    def validate_against_known_models(self, provider_instances: dict[ProviderType, Any]) -> None:
        """
        Validate restrictions against known models from providers.

        This should be called after providers are initialized to warn about
        typos or invalid model names in the restriction lists. Only the
        per-provider allow-lists are checked; nothing is modified and nothing
        is raised - problems are reported as log warnings.

        Args:
            provider_instances: Dictionary of provider type to provider instance.
                Each instance is expected to offer ``list_all_known_models()``
                returning an iterable of model-name strings.
        """
        for provider_type, allowed_models in self.restrictions.items():
            provider = provider_instances.get(provider_type)
            if not provider:
                continue

            # Use list_all_known_models to get both aliases and their targets
            try:
                all_models = provider.list_all_known_models()
                supported_models = {model.lower() for model in all_models}
            except Exception as e:
                # Without a model list there is nothing to compare against.
                # Skip this provider instead of flagging every configured
                # model as a typo against an empty "known models" list.
                logger.debug(f"Could not get model list from {provider_type.value} provider: {e}")
                continue

            unknown_models = sorted(allowed_models - supported_models)
            if not unknown_models:
                continue

            # Sorted once here; it is repeated verbatim in every warning below.
            known_models = sorted(supported_models)
            for allowed_model in unknown_models:
                logger.warning(
                    f"Model '{allowed_model}' in {self.ENV_VARS[provider_type]} "
                    f"is not a recognized {provider_type.value} model. "
                    f"Please check for typos. Known models: {known_models}"
                )

    def is_allowed(self, provider_type: ProviderType, model_name: str, original_name: Optional[str] = None) -> bool:
        """
        Check if a model is allowed for a specific provider.

        This checks in order:
        1. If model matches any disabled patterns - BLOCKED
        2. If model is in the blocked models list - BLOCKED
        3. If provider has allowed list and model is not in it - BLOCKED
        4. Otherwise - ALLOWED

        Both ``model_name`` and ``original_name`` are compared case-insensitively.
        For the two blocking steps a hit on either name blocks the model; for
        the allow-list a hit on either name is enough to allow it.

        Args:
            provider_type: The provider type (OPENAI, GOOGLE, etc.)
            model_name: The canonical model name (after alias resolution)
            original_name: The original model name before alias resolution (optional)

        Returns:
            True if allowed, False if restricted
        """
        # A set, so an original name equal to the resolved one collapses to one entry.
        names_to_check = {model_name.lower()}
        if original_name:
            names_to_check.add(original_name.lower())

        # First check disabled patterns (highest priority); plain substring match.
        for pattern in self.disabled_patterns:
            for name in names_to_check:
                if pattern in name:
                    logger.debug(f"Model '{name}' blocked by pattern '{pattern}'")
                    return False

        # Then check explicitly blocked models
        for name in names_to_check:
            if name in self.blocked_models:
                logger.debug(f"Model '{name}' is explicitly blocked")
                return False

        # Finally check the provider-specific allow-list. A missing entry and an
        # empty set both mean "no allow-list configured", i.e. everything passes.
        allowed_set = self.restrictions.get(provider_type)
        if not allowed_set:
            return True

        # If any of the names is in the allowed set, it's allowed
        return not names_to_check.isdisjoint(allowed_set)

    def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
        """
        Get the set of allowed models for a provider.

        Args:
            provider_type: The provider type

        Returns:
            Set of allowed (lower-cased) model names, or None if no allow-list
            is configured for the provider. The returned set is the service's
            own object, not a copy.
        """
        return self.restrictions.get(provider_type)

    def has_restrictions(self, provider_type: ProviderType) -> bool:
        """
        Check if a provider has an allow-list configured.

        Only the per-provider allow-list is considered; the global block-list
        and disabled patterns do not influence the result.

        Args:
            provider_type: The provider type

        Returns:
            True if an allow-list exists for the provider, False otherwise
        """
        return provider_type in self.restrictions

    def filter_models(self, provider_type: ProviderType, models: list[str]) -> list[str]:
        """
        Filter a list of models based on restrictions.

        This applies all restrictions: patterns, blocked models, and allowed lists.

        Args:
            provider_type: The provider type
            models: List of model names to filter

        Returns:
            Filtered list containing only allowed models, in input order
        """
        # Always check is_allowed which handles all restriction types
        return [m for m in models if self.is_allowed(provider_type, m)]

    def get_restriction_summary(self) -> dict[str, Any]:
        """
        Get a summary of all restrictions for logging/debugging.

        Returns:
            Dictionary keyed by provider value with the sorted allow-list, plus
            ``"blocked_models"`` and ``"disabled_patterns"`` entries when those
            are configured. The values are fresh lists, so callers may modify
            the summary without affecting the service.
        """
        summary: dict[str, Any] = {}

        # Provider-specific allowed models
        for provider_type, allowed_set in self.restrictions.items():
            if allowed_set:
                summary[provider_type.value] = sorted(allowed_set)
            else:
                # NOTE(review): is_allowed() treats an empty allow-list as
                # "everything allowed", which contradicts this label. The text
                # is left unchanged in case consumers match on it - confirm the
                # intended meaning. Loading from the environment never stores
                # an empty set, so this only triggers if one is assigned directly.
                summary[provider_type.value] = "none (provider disabled)"

        # Global blocks
        if self.blocked_models:
            summary["blocked_models"] = sorted(self.blocked_models)

        if self.disabled_patterns:
            # Copy so the caller cannot mutate the live pattern list.
            summary["disabled_patterns"] = list(self.disabled_patterns)

        return summary


# Global instance (singleton pattern).
# Starts as None and is filled in by get_restriction_service() on its first
# call, so importing this module does not by itself read any restriction
# environment variables.
_restriction_service: Optional[ModelRestrictionService] = None


def get_restriction_service() -> ModelRestrictionService:
    """
    Return the shared ModelRestrictionService, creating it on first use.

    The instance is constructed once - which is when the restriction
    environment variables are read - and the same object is handed back on
    every later call.

    Returns:
        The singleton ModelRestrictionService instance
    """
    global _restriction_service
    service = _restriction_service
    if service is None:
        # First call: build the service and remember it for subsequent callers.
        service = ModelRestrictionService()
        _restriction_service = service
    return service
