"""FastAPI application for trigger server."""

from __future__ import annotations

import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional

from fastapi import Depends, FastAPI, HTTPException, Request
from langchain_auth.client import Client
from langgraph_sdk import get_client
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from .cron_manager import CronTriggerManager
from .database import TriggerDatabaseInterface, create_database
from .decorators import TriggerTemplate

logger = logging.getLogger(__name__)


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Gate every non-public request behind the configured auth handler.

    Requests under ``/webhooks/``, the ``/`` and ``/health`` endpoints, and
    ``OPTIONS`` requests are forwarded without authentication. For all other
    requests ``auth_handler({}, headers)`` is awaited. A result that is falsy
    or has a falsy ``"identity"`` entry yields a 401 "Authentication required".
    Any exception raised while authenticating yields a 401
    "Authentication failed". On success the handler's result is stored on
    ``request.state.current_user`` before the request continues.
    """

    def __init__(self, app, auth_handler: Callable):
        super().__init__(app)
        self.auth_handler = auth_handler

    async def dispatch(self, request: Request, call_next):
        def reject(body: str) -> Response:
            # Both failure modes answer 401 with a small JSON body.
            return Response(
                content=body,
                status_code=401,
                media_type="application/json",
            )

        path = request.url.path
        is_public = (
            path.startswith("/webhooks/")
            or path in ["/", "/health"]
            or request.method == "OPTIONS"
        )
        if is_public:
            return await call_next(request)

        try:
            identity = await self.auth_handler({}, dict(request.headers))
            # Evaluated inside the try: a malformed handler result (e.g. one
            # without ``.get``) must be reported as "Authentication failed".
            authenticated = bool(identity) and bool(identity.get("identity"))
            if authenticated:
                # Read back by get_current_user for the route handlers.
                request.state.current_user = identity
        except Exception as e:
            logger.error(f"Authentication middleware error: {e}")
            return reject('{"detail": "Authentication failed"}')

        if not authenticated:
            return reject('{"detail": "Authentication required"}')

        return await call_next(request)


def get_current_user(request: Request) -> Dict[str, Any]:
    """Return the identity the authentication middleware stored on the request.

    Intended for use as a FastAPI dependency (``Depends(get_current_user)``).

    Raises:
        HTTPException: 401 when ``request.state`` carries no ``current_user``.
    """
    missing = object()
    user = getattr(request.state, "current_user", missing)
    if user is missing:
        raise HTTPException(status_code=401, detail="Authentication required")
    return user


class TriggerServer:
    """FastAPI application for trigger webhooks.

    Trigger templates are registered in memory with :meth:`add_trigger` /
    :meth:`add_triggers`. Templates that define a ``trigger_handler`` also get
    a ``POST /webhooks/{id}`` route. On application startup the templates are
    written to the database (:meth:`ensure_trigger_templates`) and the cron
    manager is started. On shutdown the cron manager is stopped.

    Environment variables read at construction:
        LANGGRAPH_API_URL: required base URL of the LangGraph API.
        LANGSMITH_API_KEY: optional, passed as ``api_key`` to the LangGraph
            SDK client.
        LANGCHAIN_API_KEY: optional. When set, a LangChain auth client is
            created and handed to registration and trigger handlers.
            Otherwise ``None`` is passed.
    """

    def __init__(
        self,
        auth_handler: Callable,
    ):
        """Create the FastAPI app, external clients, middleware and routes.

        Args:
            auth_handler: Async callable given to ``AuthenticationMiddleware``
                to authenticate non-public requests.

        Raises:
            ValueError: If ``LANGGRAPH_API_URL`` is not set.
        """
        self.app = FastAPI(
            title="Triggers Server",
            description="Event-driven triggers framework",
            version="0.1.0"
        )

        self.database = create_database()
        self.auth_handler = auth_handler

        # LangGraph configuration
        self.langgraph_api_url = os.getenv("LANGGRAPH_API_URL")
        self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")

        if not self.langgraph_api_url:
            raise ValueError("LANGGRAPH_API_URL environment variable is required")

        self.langgraph_api_url = self.langgraph_api_url.rstrip("/")

        # Initialize LangGraph SDK client
        self.langgraph_client = get_client(url=self.langgraph_api_url, api_key=self.langsmith_api_key)

        # Initialize LangChain auth client (None disables OAuth token injection)
        langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
        if langchain_api_key:
            self.langchain_auth_client = Client(api_key=langchain_api_key)
            logger.info("✓ LangChain auth client initialized")
        else:
            self.langchain_auth_client = None
            logger.warning("LANGCHAIN_API_KEY not found - OAuth token injection disabled")

        self.triggers: List[TriggerTemplate] = []

        # The cron manager is given a reference to this server instance.
        self.cron_manager = CronTriggerManager(self)

        # Setup authentication middleware
        self.app.add_middleware(AuthenticationMiddleware, auth_handler=auth_handler)

        # Setup routes
        self._setup_routes()

        # Persist templates and start cron scheduling when the app starts;
        # stop the scheduler when it shuts down.
        @self.app.on_event("startup")
        async def startup_event():
            await self.ensure_trigger_templates()
            await self.cron_manager.start()

        @self.app.on_event("shutdown")
        async def shutdown_event():
            await self.cron_manager.shutdown()

    def add_trigger(self, trigger: TriggerTemplate) -> None:
        """Add a trigger template to the app.

        Templates with a ``trigger_handler`` also get a
        ``POST /webhooks/{trigger.id}`` route. The template is written to the
        database later, by :meth:`ensure_trigger_templates` at startup.

        Raises:
            ValueError: If a template with the same ``id`` was already added.
        """
        if any(t.id == trigger.id for t in self.triggers):
            raise ValueError(f"Trigger with id '{trigger.id}' already exists")

        self.triggers.append(trigger)

        if trigger.trigger_handler:
            # ``trigger`` is this call's parameter, so each route's closure is
            # bound to its own template (no late-binding issue).
            async def handler_endpoint(request: Request) -> Dict[str, Any]:
                return await self._handle_request(trigger, request)

            handler_path = f"/webhooks/{trigger.id}"
            self.app.post(handler_path)(handler_endpoint)
            logger.info(f"Added handler route: POST {handler_path}")

        logger.info(f"Registered trigger template in memory: {trigger.name} ({trigger.id})")

    async def ensure_trigger_templates(self) -> None:
        """Create a database row for every in-memory template that lacks one.

        Existing rows are left untouched (not updated).
        """
        for trigger in self.triggers:
            existing = await self.database.get_trigger_template(trigger.id)
            if not existing:
                logger.info(f"Creating new trigger template in database: {trigger.name} ({trigger.id})")
                await self.database.create_trigger_template(
                    id=trigger.id,
                    name=trigger.name,
                    description=trigger.description,
                    registration_schema=trigger.registration_model.model_json_schema()
                )
                logger.info(f"✓ Successfully created trigger template: {trigger.name} ({trigger.id})")
            else:
                logger.info(f"✓ Trigger template already exists in database: {trigger.name} ({trigger.id})")

    def add_triggers(self, triggers: List[TriggerTemplate]) -> None:
        """Add multiple triggers via :meth:`add_trigger`, in order."""
        for trigger in triggers:
            self.add_trigger(trigger)

    def _setup_routes(self) -> None:
        """Register the built-in routes on ``self.app``.

        ``GET /`` and ``GET /health`` are skipped by the authentication
        middleware. Routes that depend on ``get_current_user`` are scoped to
        that user's ``"identity"``.
        """

        @self.app.get("/")
        async def root() -> Dict[str, str]:
            return {"message": "Triggers Server", "version": "0.1.0"}

        @self.app.get("/health")
        async def health() -> Dict[str, str]:
            return {"status": "healthy"}

        @self.app.get("/api/triggers")
        async def api_list_triggers() -> Dict[str, Any]:
            """List available trigger templates (as stored in the database)."""
            templates = await self.database.get_trigger_templates()
            trigger_list = [
                {
                    "id": template["id"],
                    "displayName": template["name"],
                    "description": template["description"],
                    "path": "/api/triggers/registrations",
                    "method": "POST",
                    "payloadSchema": template.get("registration_schema", {}),
                }
                for template in templates
            ]

            return {
                "success": True,
                "data": trigger_list
            }

        @self.app.get("/api/triggers/registrations")
        async def api_list_registrations(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
            """List user's trigger registrations (user-scoped)."""
            try:
                user_id = current_user["identity"]

                user_registrations = await self.database.get_user_trigger_registrations(user_id)

                registrations = []
                for reg in user_registrations:
                    linked_agent_ids = await self.database.get_agents_for_trigger(reg["id"])
                    # The nested "trigger_templates" entry may be missing or
                    # None; don't let one such row turn the listing into a 500.
                    template = reg.get("trigger_templates") or {}

                    registrations.append({
                        "id": reg["id"],
                        "user_id": reg["user_id"],
                        "template_id": template.get("id"),
                        "resource": reg["resource"],
                        "linked_assistant_ids": linked_agent_ids,  # For backward compatibility
                        "created_at": reg["created_at"]
                    })

                return {
                    "success": True,
                    "data": registrations
                }

            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"Error listing registrations: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.post("/api/triggers/registrations")
        async def api_create_registration(request: Request, current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
            """Create a new trigger registration.

            The JSON body must contain ``"type"`` (a registered trigger id).
            The body is validated against that trigger's
            ``registration_model``. Returns 400 on a missing or unknown type,
            an invalid payload, or a duplicate resource. If the trigger's
            registration handler sets ``create_registration`` to false, its
            ``response_body`` / ``status_code`` is returned verbatim and
            nothing is stored.
            """
            try:
                payload = await request.json()
                logger.info(f"Registration payload received: {payload}")

                user_id = current_user["identity"]
                trigger_id = payload.get("type")
                if not trigger_id:
                    raise HTTPException(status_code=400, detail="Missing required field: type")

                trigger = next((t for t in self.triggers if t.id == trigger_id), None)
                if not trigger:
                    raise HTTPException(status_code=400, detail=f"Unknown trigger type: {trigger_id}")

                # Parse payload into registration model first
                try:
                    registration_instance = trigger.registration_model(**payload)
                except Exception as e:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Invalid payload for trigger: {str(e)}"
                    )

                # Check for duplicate registration based on resource data
                resource_dict = registration_instance.model_dump()
                existing_registration = await self.database.find_registration_by_resource(
                    template_id=trigger.id,
                    resource_data=resource_dict
                )

                # TODO(sam) figure out how to allow duplicates across users.....very unnatural constraint to have
                if existing_registration:
                    raise HTTPException(
                        status_code=400,
                        detail=f"A registration with this configuration already exists for trigger type '{trigger.id}'. Registration ID: {existing_registration.get('id')}"
                    )

                # Call the trigger's registration handler with parsed registration model
                result = await trigger.registration_handler(user_id, self.langchain_auth_client, registration_instance)

                # Handler asked to skip persistence (e.g. OAuth or URL
                # verification): relay its response as-is.
                if not result.create_registration:
                    logger.info("Registration handler requested to skip database creation")
                    return Response(
                        content=json.dumps(result.response_body),
                        status_code=result.status_code,
                        media_type="application/json"
                    )

                # Re-dumped after the handler ran. NOTE(review): presumably so
                # any changes the handler made to the instance are stored;
                # confirm before collapsing with the dump above.
                resource_dict = registration_instance.model_dump()

                registration = await self.database.create_trigger_registration(
                    user_id=user_id,
                    template_id=trigger.id,
                    resource=resource_dict,
                    metadata=result.metadata
                )

                if not registration:
                    raise HTTPException(status_code=500, detail="Failed to create trigger registration")

                # Reload cron manager to pick up any new cron registrations
                await self.cron_manager.reload_from_database()

                return {
                    "success": True,
                    "data": registration,
                    "metadata": result.metadata
                }

            except HTTPException:
                raise
            except Exception as e:
                logger.exception(f"Error creating trigger registration: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.get("/api/triggers/registrations/{registration_id}/agents")
        async def api_list_registration_agents(registration_id: str, current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
            """List agents linked to this registration (owner only, else 404)."""
            try:
                user_id = current_user["identity"]

                # Same ownership check as the link/unlink routes. (Previously
                # this called a helper with an undefined ``token`` name, so
                # every request failed with a 500.)
                registration = await self.database.get_trigger_registration(registration_id, user_id)
                if not registration:
                    raise HTTPException(status_code=404, detail="Trigger registration not found or access denied")

                # Same lookup the list route exposes as "linked_assistant_ids".
                linked_agent_ids = await self.database.get_agents_for_trigger(registration_id)

                return {
                    "success": True,
                    "data": linked_agent_ids or []
                }

            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"Error getting registration agents: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.post("/api/triggers/registrations/{registration_id}/agents/{agent_id}")
        async def api_add_agent_to_trigger(registration_id: str, agent_id: str, request: Request, current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
            """Add an agent to a trigger registration.

            The JSON body is optional. When it is an object, its
            ``"field_selection"`` value is stored with the link.
            """
            try:
                # Tolerate an absent/invalid body, but don't swallow
                # BaseException (e.g. task cancellation) like a bare except.
                try:
                    body = await request.json()
                except Exception:
                    body = None
                field_selection = body.get("field_selection") if isinstance(body, dict) else None

                user_id = current_user["identity"]

                # Verify the trigger registration exists and belongs to the user
                registration = await self.database.get_trigger_registration(registration_id, user_id)
                if not registration:
                    raise HTTPException(status_code=404, detail="Trigger registration not found or access denied")

                success = await self.database.link_agent_to_trigger(
                    agent_id=agent_id,
                    registration_id=registration_id,
                    created_by=user_id,
                    field_selection=field_selection
                )

                if not success:
                    raise HTTPException(status_code=500, detail="Failed to link agent to trigger")

                return {
                    "success": True,
                    "message": f"Successfully linked agent {agent_id} to trigger {registration_id}",
                    "data": {
                        "registration_id": registration_id,
                        "agent_id": agent_id
                    }
                }

            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"Error linking agent to trigger: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.app.delete("/api/triggers/registrations/{registration_id}/agents/{agent_id}")
        async def api_remove_agent_from_trigger(registration_id: str, agent_id: str, current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
            """Remove an agent from a trigger registration (owner only, else 404)."""
            try:
                user_id = current_user["identity"]

                # Verify the trigger registration exists and belongs to the user
                registration = await self.database.get_trigger_registration(registration_id, user_id)
                if not registration:
                    raise HTTPException(status_code=404, detail="Trigger registration not found or access denied")

                success = await self.database.unlink_agent_from_trigger(
                    agent_id=agent_id,
                    registration_id=registration_id
                )

                if not success:
                    raise HTTPException(status_code=500, detail="Failed to unlink agent from trigger")

                return {
                    "success": True,
                    "message": f"Successfully unlinked agent {agent_id} from trigger {registration_id}",
                    "data": {
                        "registration_id": registration_id,
                        "agent_id": agent_id
                    }
                }

            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"Error unlinking agent from trigger: {e}")
                raise HTTPException(status_code=500, detail=str(e))

    async def _handle_request(
        self,
        trigger: TriggerTemplate,
        request: Request
    ) -> Dict[str, Any]:
        """Run ``trigger.trigger_handler`` for a webhook call and fan out to agents.

        A JSON POST body is parsed as JSON. Any other POST body is passed as
        ``{"raw_body": <utf-8 text>}``. Non-POST requests use the query
        parameters as the payload. If the handler's result has a falsy
        ``invoke_agent``, its ``response_body`` is returned unchanged.
        Otherwise every agent linked to ``result.registration["id"]`` is
        invoked with ``result.agent_message``. Per-agent failures are logged
        and skipped.

        Returns:
            The handler's response body, or
            ``{"success": True, "agents_invoked": <count>}``.

        Raises:
            HTTPException: 500 if parsing or the handler itself fails.
        """
        try:
            if request.method == "POST":
                if request.headers.get("content-type", "").startswith("application/json"):
                    payload = await request.json()
                else:
                    # Handle form data or other content types
                    body = await request.body()
                    payload = {"raw_body": body.decode("utf-8") if body else ""}
            else:
                payload = dict(request.query_params)

            query_params = dict(request.query_params)
            result = await trigger.trigger_handler(payload, query_params, self.database, self.langchain_auth_client)
            if not result.invoke_agent:
                return result.response_body

            registration_id = result.registration["id"]
            user_id = result.registration["user_id"]
            agent_links = await self.database.get_agents_for_trigger(registration_id)

            # Every linked agent receives the same single human message.
            agent_input = {
                "messages": [
                    {"role": "human", "content": result.agent_message}
                ]
            }

            agents_invoked = 0
            for agent_link in agent_links:
                # Links may be bare agent ids or mappings with an "agent_id" key.
                agent_id = agent_link if isinstance(agent_link, str) else agent_link.get("agent_id")

                try:
                    success = await self._invoke_agent(
                        agent_id=agent_id,
                        user_id=user_id,
                        input_data=agent_input,
                    )
                    if success:
                        agents_invoked += 1
                except Exception as e:
                    logger.error(f"Error invoking agent {agent_id}: {e}", exc_info=True)
            logger.info(f"Processed trigger handler, invoked {agents_invoked} agents")

            return {
                "success": True,
                "agents_invoked": agents_invoked
            }

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in trigger handler: {e}", exc_info=True)
            raise HTTPException(
                status_code=500,
                detail=f"Trigger processing failed: {str(e)}"
            )

    async def _invoke_agent(
        self,
        agent_id: str,
        user_id: str,
        input_data: Dict[str, Any],
    ) -> bool:
        """Start a run of ``agent_id`` on a fresh LangGraph thread for ``user_id``.

        Both the thread and the run are created with the ``x-auth-scheme`` /
        ``x-supabase-user-id`` headers and ``triggered_by`` / ``user_id``
        metadata.

        Returns:
            True once the run is created. False if the SDK error carries a
            404 response (agent not found).

        Raises:
            Exception: Any other error from the LangGraph client is logged
                and re-raised.
        """
        logger.info(f"Invoking LangGraph agent {agent_id} for user {user_id}")

        try:
            headers = {
                "x-auth-scheme": "oap-trigger",
                "x-supabase-user-id": user_id,
            }

            thread = await self.langgraph_client.threads.create(
                metadata={
                    "triggered_by": "langchain-triggers",
                    "user_id": user_id,
                },
                headers=headers,
            )
            logger.info(f"Created thread {thread['thread_id']} for agent {agent_id}")

            run = await self.langgraph_client.runs.create(
                thread_id=thread['thread_id'],
                assistant_id=agent_id,
                input=input_data,
                metadata={
                    "triggered_by": "langchain-triggers",
                    "user_id": user_id,
                },
                headers=headers,
            )

            logger.info(f"Successfully invoked agent {agent_id}, run_id: {run['run_id']}, thread_id: {run['thread_id']}")
            return True

        except Exception as e:
            # Handle 404s (agent not found) as warnings, not errors
            if hasattr(e, 'response') and getattr(e.response, 'status_code', None) == 404:
                logger.warning(f"Agent {agent_id} not found (404) - agent may have been deleted or moved")
                return False
            logger.error(f"Error invoking agent {agent_id}: {e}")
            raise

    def get_app(self) -> FastAPI:
        """Get the FastAPI app instance."""
        return self.app