#!/usr/bin/env python3
"""
Scorpius Vulnerability Detector
Core vulnerability detection engine with AI-powered pattern recognition
"""

import re
import time
import hashlib
from typing import Dict, List, Any, Optional
from datetime import datetime
import logging

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

class VulnerabilityDetector:
    """
    Core vulnerability detection engine
    Combines rule-based detection with AI predictions

    Rule-based detection runs the regular expressions stored in
    ``vulnerability_patterns`` over the contract source. When a
    ``learning_system`` is supplied, its prediction is added as one extra
    finding. Every finding is a plain dict with the keys ``id``, ``type``,
    ``severity``, ``confidence``, ``description``, ``recommendation``,
    ``source``, ``line_number`` and ``code_snippet``.
    """

    # A rule finding is kept only when its context confidence exceeds this.
    # _check_context_patterns never returns less than 0.5, so with the
    # current scoring formula this threshold does not filter anything.
    MIN_CONTEXT_CONFIDENCE = 0.3
    # An AI prediction is reported only when its confidence exceeds this.
    AI_CONFIDENCE_THRESHOLD = 0.5
    # Upper bound applied to the confidence of rule-based findings.
    MAX_RULE_CONFIDENCE = 0.98
    # Characters inspected on each side of a match when scoring context.
    CONTEXT_WINDOW_CHARS = 200

    def __init__(self, learning_system: Optional[Any] = None):
        """
        Create a detector and load the built-in detection patterns.

        Args:
            learning_system: Optional object exposing an awaitable
                ``predict_vulnerability_type(contract_code)``. When falsy,
                only rule-based detection runs.
        """
        self.learning_system = learning_system
        self.detection_rules = {}
        self.vulnerability_patterns = {}

        # Initialize core detection patterns
        self._init_detection_patterns()

    def _init_detection_patterns(self):
        """
        Initialize core vulnerability detection patterns

        Each entry of ``vulnerability_patterns`` maps a vulnerability type to:

        * ``patterns`` - regexes searched over the whole source; the first
          one that matches produces the finding for that type.
        * ``context_patterns`` - regexes that raise confidence when they
          match in the window around the match.
        * ``absent_context_patterns`` (optional) - regexes that raise
          confidence when they do NOT match in that window.
        * ``severity`` / ``confidence`` - reported severity and the base
          confidence that the context score is multiplied into.
        """

        self.vulnerability_patterns = {
            'reentrancy': {
                'patterns': [
                    r'\.call\{value:\s*\w+\}',
                    r'\.transfer\s*\(',
                    r'\.send\s*\(',
                    r'msg\.sender\.call'
                ],
                'context_patterns': [
                    r'balances\[.*\]\s*[-=]',
                    r'state.*after.*call',
                    r'external.*call.*before.*update'
                ],
                'severity': 'High',
                'confidence': 0.85
            },
            'access_control': {
                # NOTE(review): the negative lookaheads below follow a greedy
                # [^}]* and can always be satisfied by backtracking, so they
                # do not actually exclude guarded functions. Left unchanged.
                'patterns': [
                    r'function\s+\w+.*external.*{[^}]*(?!onlyOwner|require\s*\(\s*msg\.sender)',
                    r'admin\s*=\s*\w+',
                    r'owner\s*=\s*\w+',
                    r'function.*setAdmin.*{[^}]*(?!require)'
                ],
                'context_patterns': [
                    r'missing.*modifier',
                    r'unauthorized',
                    r'privilege.*escalation'
                ],
                'severity': 'High',
                'confidence': 0.80
            },
            'oracle_manipulation': {
                'patterns': [
                    r'\.getPrice\s*\(\)',
                    r'\.latestAnswer\s*\(\)',
                    r'getReserves\s*\(\)',
                    r'oracle\.\w+\s*\(\)'
                ],
                'context_patterns': [
                    r'single.*source',
                    r'spot.*price',
                    r'manipulat'
                ],
                # Previously written as the bare lookahead (?!TWAP|average|median),
                # which re.search satisfies in every string. Expressed as an
                # absence check so a nearby mitigation lowers the score.
                'absent_context_patterns': [
                    r'TWAP|average|median'
                ],
                'severity': 'Critical',
                'confidence': 0.90
            },
            'flash_loan_attack': {
                'patterns': [
                    r'flashLoan',
                    r'flash.*loan',
                    r'borrow.*repay',
                    r'flashmint'
                ],
                'context_patterns': [
                    r'price.*manipulat',
                    r'arbitrage',
                    r'atomic.*transaction'
                ],
                'severity': 'Critical',
                'confidence': 0.88
            },
            'integer_overflow': {
                'patterns': [
                    r'\+\s*\w+(?!.*SafeMath)',
                    r'\-\s*\w+(?!.*SafeMath)',
                    r'\*\s*\w+(?!.*SafeMath)',
                    r'\/\s*\w+(?!.*SafeMath)'
                ],
                'context_patterns': [
                    r'pragma.*solidity.*0\.[0-7]'
                ],
                # Previously (?!.*SafeMath) and (?!.*checked), which matched
                # unconditionally; now scored only when the token is absent.
                # NOTE(review): matching is case-insensitive, so 'checked'
                # also matches 'unchecked' - confirm that is intended.
                'absent_context_patterns': [
                    r'SafeMath',
                    r'checked'
                ],
                'severity': 'Medium',
                'confidence': 0.75
            },
            'governance_attack': {
                'patterns': [
                    r'vote\s*\(',
                    r'proposal',
                    r'governance.*token',
                    r'voting.*power'
                ],
                'context_patterns': [
                    r'flash.*loan',
                    r'borrow.*vote',
                    r'delegate'
                ],
                'severity': 'High',
                'confidence': 0.82
            }
        }

    async def initialize(self):
        """Initialize the vulnerability detector (currently only logs)."""
        logger.info("🔍 Vulnerability detector initialized")

    async def detect_vulnerabilities(
        self, contract_code: str, contract_path: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Detect vulnerabilities in smart contract code

        Args:
            contract_code: Solidity contract source code
            contract_path: Optional path to contract file (currently unused)

        Returns:
            List of detected vulnerabilities, at most one per vulnerability
            type, ordered by descending confidence. Empty or ``None`` source
            yields an empty list. This method does not raise: if rule-based
            detection fails, the error is logged and ``[]`` is returned; if
            only the AI prediction fails, the rule-based findings are still
            returned.
        """

        if not contract_code:
            return []

        try:
            # Rule-based detection
            vulnerabilities = await self._rule_based_detection(contract_code)

            # AI-powered prediction if learning system is available
            if self.learning_system:
                try:
                    ai_vulnerability = await self._ai_detection(
                        contract_code, len(vulnerabilities)
                    )
                except Exception as e:
                    # A broken model must not throw away what the rules found.
                    logger.exception(
                        "AI vulnerability prediction failed, "
                        "keeping rule-based findings: %s", e
                    )
                else:
                    if ai_vulnerability is not None:
                        vulnerabilities.append(ai_vulnerability)

            # Deduplicate similar vulnerabilities
            return self._deduplicate_vulnerabilities(vulnerabilities)

        except Exception as e:
            logger.exception("Vulnerability detection failed: %s", e)
            return []

    async def _ai_detection(
        self, contract_code: str, existing_count: int
    ) -> Optional[Dict[str, Any]]:
        """
        Ask the learning system for a prediction and turn it into a finding.

        Args:
            contract_code: Solidity contract source code
            existing_count: Number of findings collected so far; used to
                number the ``AI-NNN`` id.

        Returns:
            A finding dict, or ``None`` when the predicted confidence does
            not exceed ``AI_CONFIDENCE_THRESHOLD``.

        Raises:
            Whatever the learning system raises, and ``KeyError`` when the
            prediction lacks an expected key.
        """
        ai_prediction = await self.learning_system.predict_vulnerability_type(contract_code)

        if not (ai_prediction['confidence'] > self.AI_CONFIDENCE_THRESHOLD):
            return None

        predicted_type = ai_prediction['predicted_type']
        return {
            'id': f"AI-{existing_count + 1:03d}",
            'type': predicted_type,
            'severity': ai_prediction['predicted_severity'],
            'confidence': ai_prediction['confidence'],
            'description': f"AI-detected {predicted_type} vulnerability",
            'recommendation': ai_prediction['recommendation'],
            'source': 'ai_prediction',
            'line_number': None,
            'code_snippet': self._extract_relevant_code(contract_code, predicted_type)
        }

    async def _rule_based_detection(self, contract_code: str) -> List[Dict[str, Any]]:
        """
        Rule-based vulnerability detection

        For each vulnerability type, the main patterns are tried in order and
        the first occurrence of the first matching pattern is reported, so at
        most one finding is produced per type.
        """

        vulnerabilities = []

        for vuln_type, pattern_config in self.vulnerability_patterns.items():
            # Check main patterns
            for pattern in pattern_config['patterns']:
                # Only the first occurrence is used, so search() is enough.
                match = re.search(pattern, contract_code, re.IGNORECASE | re.MULTILINE)
                if match is None:
                    continue

                # Check context patterns for confirmation
                context_confidence = self._check_context_patterns(
                    contract_code,
                    match,
                    pattern_config.get('context_patterns', []),
                    pattern_config.get('absent_context_patterns', []),
                )

                if context_confidence > self.MIN_CONTEXT_CONFIDENCE:
                    vulnerabilities.append({
                        'id': f"R-{len(vulnerabilities)+1:03d}",
                        'type': vuln_type,
                        'severity': pattern_config['severity'],
                        'confidence': min(
                            self.MAX_RULE_CONFIDENCE,
                            pattern_config['confidence'] * context_confidence,
                        ),
                        'description': f"Detected {vuln_type.replace('_', ' ')} vulnerability",
                        'recommendation': self._get_recommendation(vuln_type),
                        'source': 'rule_based',
                        'line_number': self._get_line_number(contract_code, match.start()),
                        'code_snippet': self._extract_code_context(contract_code, match)
                    })
                    break  # Only report one instance per vulnerability type

        return vulnerabilities

    def _check_context_patterns(
        self,
        code: str,
        match,
        context_patterns: List[str],
        absent_patterns: Optional[List[str]] = None,
    ) -> float:
        """
        Check context patterns around a match

        Args:
            code: Full source text the match was found in.
            match: ``re.Match`` locating the suspicious construct.
            context_patterns: Regexes that count as a hit when found in the
                window around the match.
            absent_patterns: Optional regexes that count as a hit when NOT
                found in that window.

        Returns:
            ``1.0`` when no patterns of either kind are given; otherwise
            ``0.5 + 0.5 * (hits / total patterns)``, i.e. a value between
            0.5 and 1.0.
        """

        context_patterns = context_patterns or []
        absent_patterns = absent_patterns or []

        total = len(context_patterns) + len(absent_patterns)
        if total == 0:
            return 1.0

        # Extract context around the match
        start = max(0, match.start() - self.CONTEXT_WINDOW_CHARS)
        end = min(len(code), match.end() + self.CONTEXT_WINDOW_CHARS)
        context = code[start:end]

        # Count supporting evidence: present indicators plus missing mitigations
        hits = sum(
            1 for pattern in context_patterns
            if re.search(pattern, context, re.IGNORECASE)
        )
        hits += sum(
            1 for pattern in absent_patterns
            if not re.search(pattern, context, re.IGNORECASE)
        )

        # Return confidence based on context match ratio
        return min(1.0, 0.5 + (hits / total) * 0.5)

    def _get_line_number(self, code: str, position: int) -> int:
        """Get the 1-based line number for a character position in code"""
        # Count newlines before the position without copying the prefix.
        return code.count('\n', 0, position) + 1

    def _extract_code_context(self, code: str, match) -> str:
        """Return the matched line with up to 3 lines before and after it"""

        lines = code.split('\n')
        match_line = self._get_line_number(code, match.start()) - 1

        # Extract 3 lines before and after
        start_line = max(0, match_line - 3)
        end_line = min(len(lines), match_line + 4)

        return '\n'.join(lines[start_line:end_line])

    def _extract_relevant_code(self, code: str, vuln_type: str) -> str:
        """
        Extract relevant code snippet for vulnerability type

        Finds the first line containing one of the type's keywords
        (case-sensitive substring match) and returns it with up to 5 lines
        before and 9 lines after. Falls back to the first 500 characters of
        ``code`` when no line matches or ``vuln_type`` is not a string.
        """

        vuln_keywords = {
            'reentrancy': ['call{', '.call(', '.send(', '.transfer('],
            'access_control': ['admin', 'owner', 'setAdmin', 'onlyOwner'],
            'oracle_manipulation': ['getPrice', 'oracle', 'latestAnswer'],
            'flash_loan_attack': ['flashLoan', 'borrow', 'repay'],
            'integer_overflow': ['SafeMath', 'unchecked', '+=', '-=', '*='],
            'governance_attack': ['vote', 'proposal', 'governance']
        }

        if isinstance(vuln_type, str):
            # Unknown types fall back to searching for the type name itself.
            keywords = vuln_keywords.get(vuln_type, [vuln_type])
        else:
            # A non-string type (e.g. None from a model) cannot be searched for.
            keywords = []

        lines = code.split('\n')
        for i, line in enumerate(lines):
            if any(keyword in line for keyword in keywords):
                start = max(0, i - 5)
                end = min(len(lines), i + 10)
                return '\n'.join(lines[start:end])

        return code[:500]  # Fallback to first 500 characters

    def _get_recommendation(self, vuln_type: str) -> str:
        """Get recommendation for vulnerability type (generic text if unknown)"""

        recommendations = {
            'reentrancy': 'Use reentrancy guard modifier and follow checks-effects-interactions pattern',
            'access_control': 'Implement proper access control with onlyOwner or role-based modifiers',
            'oracle_manipulation': 'Use time-weighted average prices (TWAP) and multiple oracle sources',
            'flash_loan_attack': 'Implement flash loan protection and validate all price sources',
            'integer_overflow': 'Use SafeMath library or upgrade to Solidity 0.8+',
            'governance_attack': 'Add voting delays and implement flash loan protection for governance'
        }

        return recommendations.get(vuln_type, 'Implement security best practices for this vulnerability type')

    def _deduplicate_vulnerabilities(self, vulnerabilities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Remove duplicate vulnerabilities

        Keeps the highest-confidence finding for each ``type``. The result is
        ordered by descending confidence; the sort is stable, so ties keep
        their original relative order.
        """

        seen_types = set()
        deduplicated = []

        # Sort by confidence (highest first)
        sorted_vulns = sorted(vulnerabilities, key=lambda x: x.get('confidence', 0), reverse=True)

        for vuln in sorted_vulns:
            vuln_type = vuln.get('type')
            if vuln_type not in seen_types:
                seen_types.add(vuln_type)
                deduplicated.append(vuln)

        return deduplicated