#!/usr/bin/env python3
import os
import sys
import logging
import subprocess
import json
import time
import requests
from typing import Dict, Any, Optional
import paho.mqtt.client as mqtt
from dotenv import load_dotenv

# Set up logging to stdout and, when the log file is writable, to a file.
# The file path can be overridden with the MULTI_LLM_AGENT_LOG environment variable.
_log_file = os.environ.get('MULTI_LLM_AGENT_LOG', '/var/log/multi_llm_agent.log')
_log_handlers = [logging.StreamHandler(sys.stdout)]
_log_file_error = None
try:
    # File handler first, matching the original handler order.
    _log_handlers.insert(0, logging.FileHandler(_log_file))
except OSError as _exc:
    # e.g. PermissionError when not running as root: keep stdout logging
    # rather than crashing the whole agent at import time.
    _log_file_error = _exc

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger('MultiLLMAgent')
if _log_file_error is not None:
    logger.warning("File logging disabled (%s): %s", _log_file, _log_file_error)

class MultiLLMAgent:
    """MQTT-driven agent that routes tasks to one of several LLM backends.

    Tasks arrive on ``has/multi_llm/tasks`` and administrative commands on
    ``has/multi_llm/commands``; results, health reports and status updates are
    published under ``has/multi_llm/...``. Heartbeats on ``has/master/heartbeat``
    track whether the master (Claude-Code) agent is reachable, which changes how
    :meth:`select_model` routes work.
    """

    # Modes accepted on the command line and by the ``switch_mode`` command.
    VALID_MODES = ('admin', 'worker')

    def __init__(self):
        self.config = {
            'name': 'MultiLLMAgent',
            'version': '2.0.0',
            'host': 'HAS',
            'ip': '192.168.0.58',
            'mode': 'admin',  # admin | worker
            'ollama_host': '192.168.0.41:11434',
            'zen_coordinator': 'http://localhost:8020'
        }

        # Model selection thresholds
        self.thresholds = {
            'quick_task_minutes': 1,
            'complex_task_minutes': 5,
            # Quick-task cut-off used while the master agent is unavailable.
            'fallback_quick_task_minutes': 3,
            'security_critical': True
        }

        # MQTT setup
        self.mqtt_client = self._create_mqtt_client('multi_llm_agent')
        self.mqtt_client.on_connect = self.on_mqtt_connect
        self.mqtt_client.on_message = self.on_mqtt_message

        # Master agent availability flag
        self.master_available = True
        self.last_master_ping = time.time()

        # Set here so get_system_health() works even before run(); run() resets it.
        self.start_time = time.time()

    @staticmethod
    def _create_mqtt_client(client_id: str):
        """Create a paho MQTT client that works on both paho-mqtt 1.x and 2.x.

        paho-mqtt 2.0 made ``callback_api_version`` the first constructor
        argument, so the 1.x-style call ``mqtt.Client('id')`` fails there.
        Requesting ``VERSION1`` keeps the ``on_connect(client, userdata, flags, rc)``
        callback signature used by this class.
        """
        if hasattr(mqtt, 'CallbackAPIVersion'):
            return mqtt.Client(mqtt.CallbackAPIVersion.VERSION1, client_id=client_id)
        return mqtt.Client(client_id=client_id)

    def select_model(self, task: Dict[str, Any]) -> str:
        """Pick a backend for *task* from its security flag, complexity and duration.

        Reads ``security_threat`` (default False), ``complexity`` (default
        ``'standard'``) and ``estimated_time_min`` (default 1) from the task.

        Returns one of ``'ollama-qwen2'``, ``'claude-haiku'``, ``'gemini-flash'``,
        ``'gemini-pro'`` or ``'ask-master'``.
        """
        estimated_time = task.get('estimated_time_min', 1)
        complexity = task.get('complexity', 'standard')
        security_threat = task.get('security_threat', False)

        logger.info(f"Selecting model for task: {complexity}, time: {estimated_time}min, security: {security_threat}")

        # Security threats always use Gemini
        if security_threat:
            return 'gemini-pro' if complexity == 'critical' else 'gemini-flash'

        # Standard operation mode (master available)
        if self.master_available:
            if estimated_time <= self.thresholds['quick_task_minutes']:
                return 'ollama-qwen2'  # Fast local model
            if complexity == 'standard':
                return 'claude-haiku'  # Balanced cost/performance
            return 'ask-master'  # Let Claude-Code decide

        # Fallback mode (master unavailable): local model takes a wider range of tasks
        if estimated_time <= self.thresholds.get('fallback_quick_task_minutes', 3):
            return 'ollama-qwen2'
        if complexity == 'complex':
            return 'gemini-flash'
        return 'gemini-pro'  # Independent thinking

    def check_ollama_health(self) -> bool:
        """Return True if the OLLAMA server answers ``/api/tags`` with HTTP 200."""
        try:
            response = requests.get(f"http://{self.config['ollama_host']}/api/tags", timeout=5)
        except requests.RequestException:
            return False
        return response.status_code == 200

    def execute_with_model(self, task: Dict[str, Any], model: str) -> Dict[str, Any]:
        """Run *task* on the backend named by *model*.

        Returns the executor's result dict, or ``{'error': ...}`` for an
        unrecognised model name.
        """
        if model == 'ollama-qwen2':
            return self.execute_ollama(task)
        elif model == 'claude-haiku':
            return self.execute_claude_haiku(task)
        elif model.startswith('gemini'):
            return self.execute_gemini(task, model)
        elif model == 'ask-master':
            return self.ask_master_agent(task)
        else:
            return {'error': f'Unknown model: {model}'}

    def execute_ollama(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``task['prompt']`` on the local OLLAMA ``qwen2:latest`` model.

        Falls back to Gemini Flash when OLLAMA fails its health check, the
        request raises, or the server answers with a non-200 status.
        """
        if not self.check_ollama_health():
            logger.warning("OLLAMA unavailable, falling back to Gemini Flash")
            return self.execute_gemini(task, 'gemini-flash')

        payload = {
            'model': 'qwen2:latest',
            'prompt': task.get('prompt', ''),
            'stream': False
        }
        try:
            response = requests.post(f"http://{self.config['ollama_host']}/api/generate",
                                     json=payload, timeout=30)
            if response.status_code == 200:
                return {'model': 'ollama-qwen2', 'response': response.json()['response']}
            # Previously a non-200 fell through and returned None.
            logger.error(f"OLLAMA execution failed: HTTP {response.status_code}")
        except Exception as e:  # best-effort: any failure triggers the fallback
            logger.error(f"OLLAMA execution failed: {e}")
        return self.execute_gemini(task, 'gemini-flash')  # Fallback

    def execute_claude_haiku(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``task['command']`` through the ZEN Coordinator ``execute_command`` tool.

        Falls back to Gemini Flash when the request raises or the coordinator
        answers with a non-200 status.
        """
        payload = {
            'tool': 'execute_command',
            'arguments': {
                'command': task.get('command', 'echo "Task completed"')
            }
        }
        try:
            response = requests.post(f"{self.config['zen_coordinator']}/mcp",
                                     json=payload, timeout=60)
            if response.status_code == 200:
                return {'model': 'claude-haiku', 'response': response.json()}
            logger.error(f"Claude Haiku execution failed: HTTP {response.status_code}")
        except Exception as e:  # best-effort: any failure triggers the fallback
            logger.error(f"Claude Haiku execution failed: {e}")
        return self.execute_gemini(task, 'gemini-flash')

    def execute_gemini(self, task: Dict[str, Any], model: str) -> Dict[str, Any]:
        """Run ``task['prompt']`` through the ZEN Coordinator ``research_query`` tool.

        This is the last fallback in the chain, so on any failure (exception
        or non-200 status) it returns an ``{'error': ...}`` dict instead of
        falling back further.
        """
        # Use ZEN Coordinator research MCP for Gemini integration
        payload = {
            'tool': 'research_query',
            'arguments': {
                'query': task.get('prompt', 'System status check'),
                'model': model
            }
        }
        try:
            response = requests.post(f"{self.config['zen_coordinator']}/mcp",
                                     json=payload, timeout=90)
            if response.status_code == 200:
                return {'model': model, 'response': response.json()}
            logger.error(f"Gemini execution failed: HTTP {response.status_code}")
        except Exception as e:
            logger.error(f"Gemini execution failed: {e}")
        return {'error': f'All models failed for task: {task}'}

    def ask_master_agent(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Publish *task* to the master Claude-Code agent's escalation topic."""
        logger.info("Escalating complex task to master agent")
        self.mqtt_client.publish('has/claude_code/escalation', json.dumps(task))
        return {'model': 'claude-master', 'status': 'escalated'}

    def on_mqtt_connect(self, client, userdata, flags, rc):
        """Subscribe to the agent's topics once the broker accepts the connection."""
        logger.info(f'Multi-LLM Agent connected to MQTT with result code {rc}')
        if rc != 0:
            # Connection refused; subscribing would fail. paho retries the connect.
            logger.error(f'MQTT connection refused (result code {rc}); not subscribing')
            return
        client.subscribe('has/multi_llm/commands')
        client.subscribe('has/multi_llm/tasks')
        client.subscribe('has/master/heartbeat')  # Monitor master agent

    def on_mqtt_message(self, client, userdata, msg):
        """Dispatch an incoming MQTT message by topic.

        Heartbeats refresh the master-agent liveness state regardless of their
        payload. Command and task messages must carry a JSON object.

        NOTE(review): handlers run synchronously inside this paho callback, so
        a slow model call (HTTP timeouts in this class go up to 90 s) delays
        every later message, heartbeats included.
        """
        topic = msg.topic

        if topic == 'has/master/heartbeat':
            # Any heartbeat counts; its payload is not needed, so it is not
            # required to be valid JSON.
            self.last_master_ping = time.time()
            if not self.master_available:
                logger.info("Master agent available again")
            self.master_available = True
            return

        try:
            payload = json.loads(msg.payload.decode())
        except ValueError as e:  # invalid UTF-8 or invalid JSON
            logger.error(f'Error processing MQTT message: {e}')
            return

        logger.info(f'Received message on {topic}: {payload}')

        if not isinstance(payload, dict):
            logger.error(f'Ignoring non-object payload on {topic}')
            return

        try:
            if topic == 'has/multi_llm/commands':
                self.handle_command(payload)
            elif topic == 'has/multi_llm/tasks':
                self.handle_task(payload)
        except Exception:
            # Never let a handler error escape into paho's network loop.
            logger.exception('Error processing MQTT message')

    def handle_command(self, command: Dict[str, Any]):
        """Handle an administrative command message.

        Supported ``command`` values: ``switch_mode``, ``system_health``,
        ``model_test`` and ``wake_workstation``. Unknown commands are logged
        and ignored.
        """
        cmd_type = command.get('command')

        if cmd_type == 'switch_mode':
            new_mode = command.get('mode', 'admin')
            if new_mode not in self.VALID_MODES:
                logger.error(f"Ignoring switch_mode to unknown mode: {new_mode}")
                return
            # NOTE: this only updates the reported mode; run() chose the
            # admin/worker loop at startup and that loop is not restarted here.
            self.config['mode'] = new_mode
            logger.info(f"Switched to {new_mode} mode")
            self.mqtt_client.publish('has/multi_llm/status',
                                     json.dumps({'mode': new_mode, 'timestamp': time.time()}))

        elif cmd_type == 'system_health':
            health = self.get_system_health()
            self.mqtt_client.publish('has/multi_llm/health', json.dumps(health))

        elif cmd_type == 'model_test':
            test_results = self.test_all_models()
            self.mqtt_client.publish('has/multi_llm/test_results', json.dumps(test_results))

        elif cmd_type == 'wake_workstation':
            self.wake_workstation()

        else:
            logger.warning(f"Unknown command: {cmd_type}")

    def handle_task(self, task: Dict[str, Any]):
        """Select a model for *task*, execute it and publish the result.

        The published result on ``has/multi_llm/results`` is the executor's
        dict plus ``execution_time``, ``task_id`` and ``selected_model``.
        """
        start_time = time.time()

        # Select appropriate model
        selected_model = self.select_model(task)
        logger.info(f"Selected model: {selected_model} for task: {task.get('id', 'unknown')}")

        # Execute task
        result = self.execute_with_model(task, selected_model)
        if not isinstance(result, dict):
            # Defensive: never let a bad executor return value stop the
            # result from being published.
            result = {'error': f'No result from model: {selected_model}'}

        # Add execution metadata
        result['execution_time'] = time.time() - start_time
        result['task_id'] = task.get('id')
        result['selected_model'] = selected_model

        # Publish result
        self.mqtt_client.publish('has/multi_llm/results', json.dumps(result))

    def get_system_health(self) -> Dict[str, Any]:
        """Collect a health snapshot (backends, master agent, mode, memory)."""
        return {
            'agent_uptime': time.time() - self.start_time,
            'ollama_available': self.check_ollama_health(),
            'zen_coordinator_available': self.check_zen_health(),
            'master_agent_available': self.master_available,
            'last_master_ping': self.last_master_ping,
            'current_mode': self.config['mode'],
            'memory_usage': self.get_memory_usage(),
            'timestamp': time.time()
        }

    def check_zen_health(self) -> bool:
        """Return True if the ZEN Coordinator answers ``/health`` with HTTP 200."""
        try:
            response = requests.get(f"{self.config['zen_coordinator']}/health", timeout=5)
        except requests.RequestException:
            return False
        return response.status_code == 200

    def get_memory_usage(self) -> str:
        """Return the "used" memory figure from ``free -h``, or ``'unknown'``."""
        try:
            result = subprocess.run(['free', '-h'], capture_output=True, text=True, timeout=5)
            return self._parse_free_output(result.stdout)
        except (OSError, subprocess.SubprocessError, IndexError):
            # `free` missing/hung, or output in an unexpected shape.
            return 'unknown'

    @staticmethod
    def _parse_free_output(stdout: str) -> str:
        """Return the third field of the second line of ``free`` output.

        Raises IndexError if the output does not have that many lines/fields.
        """
        return stdout.split('\n')[1].split()[2]  # Used memory

    def test_all_models(self) -> Dict[str, Any]:
        """Send a trivial prompt to each directly callable backend and report the outcome."""
        test_task = {
            'prompt': 'System status check - respond with OK',
            'complexity': 'simple',
            'estimated_time_min': 1
        }

        results = {}
        for model in ['ollama-qwen2', 'claude-haiku', 'gemini-flash']:
            try:
                start = time.time()
                result = self.execute_with_model(test_task, model)
                results[model] = {
                    'status': 'success' if 'error' not in result else 'failed',
                    'response_time': time.time() - start,
                    'response': result
                }
            except Exception as e:
                results[model] = {'status': 'failed', 'error': str(e)}

        return results

    def run_admin_loop(self, health_interval: int = 300, master_timeout: float = 60):
        """Admin mode: watch master heartbeats and publish periodic health reports.

        Runs until interrupted with Ctrl-C.

        Args:
            health_interval: Positive length, in seconds, of the wall-clock
                window. One report is published on
                ``has/multi_llm/periodic_health`` per window.
            master_timeout: Seconds without a heartbeat after which the master
                agent is marked unavailable.
        """
        logger.info("Starting admin mode - continuous monitoring")
        # Index of the last window reported. Starting at the current window
        # means the first report goes out at the next window boundary.
        last_health_window = int(time.time()) // health_interval

        while True:
            try:
                now = time.time()

                # Act only on the available -> unavailable transition, so the
                # warning is logged once instead of every second.
                if self.master_available and now - self.last_master_ping > master_timeout:
                    self.master_available = False
                    logger.warning("Master agent seems unavailable")

                # Comparing window indices (rather than `now % interval == 0`)
                # cannot miss a window when the loop oversleeps or a health
                # check blocks past the boundary second.
                health_window = int(now) // health_interval
                if health_window != last_health_window:
                    last_health_window = health_window
                    health = self.get_system_health()
                    self.mqtt_client.publish('has/multi_llm/periodic_health', json.dumps(health))

                time.sleep(1)

            except KeyboardInterrupt:
                logger.info("Admin loop interrupted")
                break
            except Exception:
                logger.exception("Admin loop error")
                time.sleep(5)

    def run_worker_mode(self):
        """Worker mode: block in the MQTT loop, handling tasks as they arrive."""
        logger.info("Starting worker mode - waiting for tasks")
        try:
            self.mqtt_client.loop_forever()
        except KeyboardInterrupt:
            logger.info("Worker loop interrupted")

    def wake_workstation(self):
        """Run the Wake-on-LAN helper script and publish the outcome.

        Publishes a JSON report to ``has/multi_llm/wake_result``. Its ``status``
        is ``'success'``/``'failed'`` (from the script's exit code),
        ``'timeout'`` (no exit within 90 s) or ``'error'`` (the script could
        not be run, e.g. it is missing).
        """
        logger.info("Executing Wake-on-LAN for workstation")
        try:
            result = subprocess.run(['/usr/local/bin/wake-workstation.sh'],
                                    capture_output=True, text=True, timeout=90)
        except subprocess.TimeoutExpired:
            response = {
                'command': 'wake_workstation',
                'status': 'timeout',
                'message': 'Wake-on-LAN timed out after 90 seconds',
                'timestamp': time.time()
            }
            logger.error("WoL operation timed out")
        except (OSError, subprocess.SubprocessError) as e:
            response = {
                'command': 'wake_workstation',
                'status': 'error',
                'message': str(e),
                'timestamp': time.time()
            }
            logger.error(f"WoL operation failed: {e}")
        else:
            response = {
                'command': 'wake_workstation',
                'status': 'success' if result.returncode == 0 else 'failed',
                'output': result.stdout,
                'error': result.stderr or None,
                'timestamp': time.time()
            }
            logger.info(f"WoL result: {response['status']}")

        self.mqtt_client.publish('has/multi_llm/wake_result', json.dumps(response))

    def run(self):
        """Connect to the local MQTT broker, announce startup and run the selected mode.

        The broker connection is always closed when the mode loop exits.
        """
        self.start_time = time.time()
        logger.info(f"Multi-LLM Agent starting in {self.config['mode']} mode...")

        # MQTT connection
        self.mqtt_client.connect('localhost', 1883, 60)

        # Publish initial status
        init_status = {
            'config': self.config,
            'startup_time': self.start_time,
            'available_models': ['ollama-qwen2', 'claude-haiku', 'gemini-flash', 'gemini-pro']
        }
        self.mqtt_client.publish('has/multi_llm/init', json.dumps(init_status))

        try:
            if self.config['mode'] == 'admin':
                # Admin mode: MQTT network loop in a background thread + admin loop
                self.mqtt_client.loop_start()
                self.run_admin_loop()
            else:
                # Worker mode: pure MQTT event loop
                self.run_worker_mode()
        finally:
            # Close the broker session, then stop the background network
            # thread (a harmless no-op when loop_start() was never called).
            self.mqtt_client.disconnect()
            self.mqtt_client.loop_stop()

if __name__ == '__main__':
    # Usage: <script> [admin|worker]   (defaults to admin)
    mode = sys.argv[1] if len(sys.argv) > 1 else 'admin'
    if mode not in MultiLLMAgent.VALID_MODES:
        sys.exit(f"Usage: {sys.argv[0]} [admin|worker] (unknown mode: {mode!r})")

    agent = MultiLLMAgent()
    agent.config['mode'] = mode
    agent.run()
