from typing import Dict, Any, TypedDict, NotRequired
import traceback
from itertools import cycle
import argparse
import json
import asyncio
from datetime import datetime
from pathlib import Path
import logging

from ..pipeline.base import BaseRAGPipeline
from ..pipeline.chroma import ChromaRAGPipeline, ChromaRAGPipelineConfig
from ..pipeline.config import (
    DocumentSource,
    PDF_SOURCE,
    MHTML_SOURCE,
    HTML_SOURCE,
    TXT_SOURCE,
)

try:
    import panel as pn
    UI_FRAMEWORK_AVAILABLE = True
except ImportError:  # ModuleNotFoundError is a subclass, so it is covered too
    # Panel is an optional extra. Bind the same stub object under both ``panel``
    # and ``pn`` so the first use of either fails with an installation hint.
    UI_FRAMEWORK_AVAILABLE = False

    class NotImported:
        """Placeholder that raises ``exc`` on any attribute access or call."""

        exc = ModuleNotFoundError(
            "Panel is not installed. "
            "Please install it using: pip install 'rag-agent[ui]'"
        )

        def __call__(self, *args, **kwargs):
            raise self.exc

        def __getattr__(self, item):
            raise self.exc

    panel = pn = NotImported()

# Load the JSONEditor extension (used by the sources editor) and enable the
# notification system used through ``pn.state.notifications``.
pn.extension("jsoneditor", notifications=True)

# Shared terminal widget; log records are mirrored into it by TerminalLogHandler.
TERMINAL = pn.widgets.Terminal(
    "",
    sizing_mode="stretch_width",
    options={"cursorBlink": True},
)


class TerminalLogHandler(logging.Handler):
    """Logging handler that writes each formatted record to a Panel terminal widget.

    Args:
        terminal: Terminal widget that receives one newline-terminated line per record.
        *args, **kwargs: Forwarded to ``logging.Handler`` (e.g. ``level``).
    """
    def __init__(self, terminal: pn.widgets.Terminal, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._terminal = terminal
        self.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )

    def emit(self, record: logging.LogRecord):
        """Append the formatted *record* to the terminal.

        Any failure is routed to ``handleError`` so logging never raises into the caller.
        """
        try:
            line = f"{self.format(record)}\n"
            self._terminal.write(line)
        except Exception:
            self.handleError(record)


# Route INFO-and-above records from every logger (via the root logger) into TERMINAL.
handler = TerminalLogHandler(TERMINAL)
handler.setLevel(logging.INFO)

root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.addHandler(handler)


class ChatWithConfigurableMessages(pn.chat.ChatInterface):
    """Chat interface that applies a fixed set of keyword arguments to every ``send``.

    Args:
        message_kwargs: Keyword arguments merged into every ``send`` call
            (e.g. ``{"show_reaction_icons": False}``). On a name clash they
            override the keyword arguments passed to ``send`` directly.
            ``None`` (the default) means no extra arguments.
        *args, **kwargs: Forwarded to ``pn.chat.ChatInterface``.
    """
    def __init__(self, message_kwargs: dict[str, Any] | None = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Private copy: neither a shared default nor later changes to the
        # caller's dict can alter the kwargs this instance applies.
        self._message_kwargs = dict(message_kwargs or {})

    def send(
        self,
        value: Any,
        user: str | None = None,
        avatar: str | None = None,
        respond: bool = True,
        trigger_post_hook: bool = True,
        **kwargs: Any
    ):
        """Send *value* after merging the configured message kwargs into ``kwargs``.

        Parameters match ``pn.chat.ChatInterface.send``.

        Returns:
            The value returned by the base-class ``send``, passed through unchanged.
        """
        kwargs.update(self._message_kwargs)
        return super().send(
            value=value,
            user=user,
            avatar=avatar,
            respond=respond,
            trigger_post_hook=trigger_post_hook,
            **kwargs
        )


class MessageWithContext(pn.chat.ChatMessage):
    """Chat message that toggles between an answer and the context behind it.

    Built from a pipeline result dict. It initially shows
    ``object["generate"]["answer"]`` as the "Assistant". A toggle icon inserted
    into the message's icon row switches to a "System" view that renders
    ``object["generate"]["context"]`` (a JSON string) as a pretty-printed JSON
    code block, and back again.
    """

    class MessageParams(TypedDict):
        """Display state for one of the two views of the message."""
        object: str
        user: str
        avatar: str
        tooltip: str  # description shown on the toggle icon while this view is active

    def __init__(self, object: Dict[str, Any], **kwargs):
        """Create the message from a pipeline result.

        Args:
            object: Result dict with ``["generate"]["answer"]`` and
                ``["generate"]["context"]`` (a JSON-encoded string).
            **kwargs: Forwarded to ``pn.chat.ChatMessage``. ``object``, ``user``
                and ``avatar`` are overridden by the initial (answer) view.

        Raises:
            KeyError: if the expected keys are missing from *object*.
            json.JSONDecodeError: if the context is not valid JSON.
        """
        generated = object["generate"]
        answer = generated["answer"]
        context_json = json.dumps(
            json.loads(generated["context"]),
            indent=2,
            # Keep non-ASCII context text readable instead of \uXXXX escapes.
            ensure_ascii=False,
        )
        # Endless alternation answer -> context -> answer ...; each toggle
        # advances it by one.
        self._ctx_params = cycle([
            MessageWithContext.MessageParams(
                object=answer,
                user="Assistant",
                avatar="🤖",
                tooltip="Show context",
            ),
            MessageWithContext.MessageParams(
                object="**Context**:\n\n```json\n{context}\n```".format(context=context_json),
                user="System",
                avatar="⚙️",
                tooltip="Hide context",
            ),
        ])

        params = next(self._ctx_params)
        # "tooltip" is not a ChatMessage parameter; it is applied to the icon below.
        kwargs.update({k: v for k, v in params.items() if k != "tooltip"})
        super().__init__(**kwargs)

        toggle_context_icon = pn.widgets.ToggleIcon(
            description=params["tooltip"],
            icon="zoom-in-area",
            active_icon="zoom-out-area",
        )
        # NOTE(review): _icons_row is a private Panel attribute; may break on upgrades.
        self._icons_row.insert(0, toggle_context_icon)
        toggle_context_icon.param.watch(self._toggle_context, "value")

    def _toggle_context(self, event: Any) -> None:
        """Switch to the next view (answer <-> context) when the toggle icon changes."""
        params = next(self._ctx_params)
        event.obj.description = params["tooltip"]
        self.object = params["object"]
        self.param.update(
            user=params["user"],
            avatar=params["avatar"],
        )


class RAGChatInterface:
    """Main chat interface for RAG system."""

    class InitialEvent:
        def __init__(self, obj: pn.widgets.Widget):
            self.new = obj.value
            self.old = None
            self.type = "change"
            self.obj = obj
            self.cls = obj

    class EmbeddingConfig(TypedDict):
        """Embedding model configuration."""
        sources: pn.widgets.JSONEditor
        model: pn.widgets.TextInput
        device: pn.widgets.Select

    class TextSplittingConfig(TypedDict):
        """Text splitting configuration."""
        chunk_size: pn.widgets.IntInput
        chunk_overlap: pn.widgets.IntInput

    class VectorStoreConfig(TypedDict):
        """Vector store configuration."""
        vector_store: pn.widgets.Select
        persist_directory: pn.widgets.TextInput
        collection_name: pn.widgets.TextInput
        video_path: pn.widgets.TextInput
        index_path: pn.widgets.TextInput
        fps: pn.widgets.IntInput
        frame_size: pn.widgets.IntInput
        video_codec: pn.widgets.TextInput
        crf: pn.widgets.IntInput
        n_workers: pn.widgets.IntInput

    class RetrievalConfig(TypedDict):
        """Retrieval configuration.

        Optional fields are dependent on search_type:
        - score_threshold: Required for similarity_score_threshold
        - fetch_k, lambda_mult: Required for mmr
        """
        search_type: pn.widgets.Select
        k: pn.widgets.IntInput
        score_threshold: NotRequired[pn.widgets.FloatInput]
        fetch_k: NotRequired[pn.widgets.IntInput]
        lambda_mult: NotRequired[pn.widgets.FloatInput]

    class LLMConfig(TypedDict):
        """LLM configuration.

        Optional fields:
        - api_key: Required for huggingface provider
        """
        provider: pn.widgets.Select
        model: pn.widgets.TextInput
        temperature: pn.widgets.FloatInput
        api_key: NotRequired[pn.widgets.PasswordInput]

    class ConfigWidgets(TypedDict):
        """All configuration widgets."""
        sources: list['RAGChatInterface.SourceConfig']
        embedding: 'RAGChatInterface.EmbeddingConfig'
        text_splitting: 'RAGChatInterface.TextSplittingConfig'
        vector_store: 'RAGChatInterface.VectorStoreConfig'
        retrieval: 'RAGChatInterface.RetrievalConfig'
        llm: 'RAGChatInterface.LLMConfig'

    class ConfigSections(TypedDict):
        """UI sections for configuration."""
        embedding: pn.Column
        text_splitting: pn.Column
        vector_store: pn.Column
        retrieval: pn.Column
        llm: pn.Column

    def __init__(self, config: Any, terminal: pn.widgets.Terminal = TERMINAL):
        self._terminal = terminal
        self._pipeline: BaseRAGPipeline | None = None
        self._reload_documents = False
        self._pipeline_config: Any = config
        self._config_widgets = self._create_config_widgets()
        self._chat_interface = self._create_chat_interface()
        self._layout = self._create_layout()

    def _create_config_widgets(self) -> list[pn.widgets.Widget]:
        """Create the configuration section with widgets for all pipeline settings."""
        widgets: RAGChatInterface.ConfigWidgets = {
            "embedding": {
                "sources": pn.widgets.JSONEditor(
                    name="Embedding source configuration",
                    value={
                        k: [
                            v.as_dict() if isinstance(v, DocumentSource) else v for v in vv
                        ] for k, vv in getattr(self._pipeline_config, "pipeline_sources", {}).items()
                    },
                    mode="tree",
                    search=False,
                    menu=False,
                    sizing_mode="stretch_width",
                    schema={
                        "type": "object",
                        "description": "Sources to be used for the pipeline",
                        "patternProperties": {
                            "^.+$": {
                                "type": "array",
                                "items": {
                                    "oneOf": [
                                        {
                                            "type": "string",
                                            "enum": ["pdf", "mhtml", "html", "txt"]
                                        },
                                        {
                                            "type": "object",
                                            "properties": {
                                                "source_type": {
                                                    "type": "string",
                                                    "enum": ["txt", "pdf", "html"]
                                                },
                                                "meta_pattern": {
                                                    "type": "string",
                                                    "description": (
                                                        "Regex pattern for extracting metadata "
                                                        "from the source path"
                                                    )
                                                },
                                                "glob_pattern": {
                                                    "type": "string",
                                                    "description": "Glob pattern for matching files in directory"
                                                }
                                            },
                                            "required": ["source_type", "glob_pattern"],
                                            "additionalProperties": False
                                        }
                                    ]
                                }
                            }
                        },
                        "additionalProperties": False
                    },
                ),
                "model": pn.widgets.TextInput(
                    name="Embedding Model",
                    value=getattr(self._pipeline_config, "pipeline_embedding_model", ""),
                    placeholder="e.g., all-MiniLM-L6-v2"
                ),
                "device": pn.widgets.AutocompleteInput(
                    name="Embedding Device",
                    value=getattr(self._pipeline_config, "pipeline_embedding_model_kwargs", {}).get("device", "cuda"),
                    options=["cuda", "cpu", "mps"],
                    restrict=False,
                ),
            },
            "text_splitting": {
                "chunk_size": pn.widgets.IntInput(
                    name="Chunk Size",
                    value=getattr(self._pipeline_config, "pipeline_chunk_size", 0),
                    start=100,
                    step=100
                ),
                "chunk_overlap": pn.widgets.IntInput(
                    name="Chunk Overlap",
                    value=getattr(self._pipeline_config, "pipeline_chunk_overlap", 0),
                    start=0,
                    step=50
                ),
            },
            "vector_store": {
                "vector_store": pn.widgets.Select(
                    name="Vector Store",
                    value="chroma",
                    options=["chroma"]
                ),
                "persist_directory": pn.widgets.TextInput(
                    name="Persist Directory",
                    value=getattr(self._pipeline_config, "pipeline_persist_directory", ""),
                    placeholder="e.g., chroma_db"
                ),
                "collection_name": pn.widgets.TextInput(
                    name="Collection Name",
                    value=getattr(self._pipeline_config, "pipeline_collection_name", ""),
                    placeholder="e.g., default_collection"
                )
            },
            "retrieval": {
                "search_type": pn.widgets.Select(
                    name="Search Type",
                    value=getattr(self._pipeline_config, "pipeline_search_type", ""),
                    options=["similarity", "mmr", "similarity_score_threshold"]
                ),
                "k": pn.widgets.IntInput(
                    name="Number of Documents (k)",
                    value=getattr(self._pipeline_config, "pipeline_k", 0),
                    start=1,
                    step=1
                ),
                "score_threshold": pn.widgets.FloatInput(
                    name="Score Threshold",
                    value=getattr(self._pipeline_config, "pipeline_score_threshold", 0) or 0.5,
                    start=0.0,
                    end=1.0,
                    step=0.1
                ),
                "fetch_k": pn.widgets.IntInput(
                    name="Fetch k",
                    value=getattr(self._pipeline_config, "pipeline_fetch_k", 0) or 20,
                    start=1,
                    step=1
                ),
                "lambda_mult": pn.widgets.FloatInput(
                    name="Lambda Multiplier",
                    value=getattr(self._pipeline_config, "pipeline_lambda_mult", 0) or 0.5,
                    start=0.0,
                    end=1.0,
                    step=0.1
                ),
            },
            "llm": {
                "provider": pn.widgets.Select(
                    name="LLM Provider",
                    value=getattr(self._pipeline_config, "pipeline_llm_provider", ""),
                    options=["ollama", "huggingface"]
                ),
                "model": pn.widgets.TextInput(
                    name="LLM Model",
                    value=getattr(self._pipeline_config, "pipeline_llm_model", ""),
                    placeholder="e.g., mistral"
                ),
                "temperature": pn.widgets.FloatInput(
                    name="Temperature",
                    value=getattr(self._pipeline_config, "pipeline_llm_model_kwargs", {}).get("temperature", 0.3),
                    start=0.0,
                    end=1.0,
                    step=0.1
                ),
                "api_key": pn.widgets.PasswordInput(
                    name="API Key",
                    value=getattr(self._pipeline_config, "pipeline_llm_api_key", "") or "",
                    placeholder="Enter API key for Hugging Face"
                ),
            },
        }

        sections: RAGChatInterface.ConfigSections = {
            "embedding": pn.Column(
                widgets["embedding"]["sources"],
                widgets["embedding"]["model"],
                widgets["embedding"]["device"],
                name="Embedding Model Settings"
            ),
            "text_splitting": pn.Column(
                widgets["text_splitting"]["chunk_size"],
                widgets["text_splitting"]["chunk_overlap"],
                name="Text Splitting Settings"
            ),
            "vector_store": pn.Column(
                widgets["vector_store"]["vector_store"],
                name="Vector Store Settings"
            ),
            "retrieval": pn.Column(
                widgets["retrieval"]["search_type"],
                widgets["retrieval"]["k"],
                name="Retrieval Settings"
            ),
            "llm": pn.Column(
                widgets["llm"]["provider"],
                widgets["llm"]["model"],
                widgets["llm"]["temperature"],
                name="LLM Settings"
            ),
        }

        def update_llm_section(event: Any):
            match event.new:
                case "huggingface":
                    if widgets["llm"]["api_key"] not in sections["llm"]:
                        sections["llm"].append(widgets["llm"]["api_key"])
                case _:
                    if widgets["llm"]["api_key"] in sections["llm"]:
                        sections["llm"].remove(widgets["llm"]["api_key"])

        def update_retrieval_section(event: Any):
            for widget in [
                widgets["retrieval"]["score_threshold"],
                widgets["retrieval"]["fetch_k"],
                widgets["retrieval"]["lambda_mult"]
            ]:
                if widget in sections["retrieval"]:
                    sections["retrieval"].remove(widget)

            match event.new:
                case "mmr":
                    sections["retrieval"].append(widgets["retrieval"]["fetch_k"])
                    sections["retrieval"].append(widgets["retrieval"]["lambda_mult"])
                case "similarity_score_threshold":
                    sections["retrieval"].append(widgets["retrieval"]["score_threshold"])

        def update_vector_store_section(event: Any):
            """Update vector store section based on selected implementation."""
            # Remove all existing widgets except the selector
            for widget in widgets["vector_store"].values():
                if widget != widgets["vector_store"]["vector_store"] and widget in sections["vector_store"]:
                    sections["vector_store"].remove(widget)

            # Add relevant widgets based on selection
            match event.new:
                case "chroma":
                    sections["vector_store"].append(widgets["vector_store"]["persist_directory"])
                    sections["vector_store"].append(widgets["vector_store"]["collection_name"])

        update_llm_section(RAGChatInterface.InitialEvent(widgets["llm"]["provider"]))
        update_retrieval_section(RAGChatInterface.InitialEvent(widgets["retrieval"]["search_type"]))
        update_vector_store_section(RAGChatInterface.InitialEvent(widgets["vector_store"]["vector_store"]))

        widgets["llm"]["provider"].param.watch(update_llm_section, "value")
        widgets["retrieval"]["search_type"].param.watch(update_retrieval_section, "value")
        widgets["vector_store"]["vector_store"].param.watch(update_vector_store_section, "value")

        reload_documents_checkbox = pn.widgets.Checkbox(
            name="Reload Documents",
            value=False,
        )
        reload_documents_checkbox.param.watch(self._toggle_reload_documents, "value")
        apply_button = pn.widgets.Button(
            name="Apply Changes",
            button_type="primary",
            height=30,
        )
        apply_button.on_click(self._apply_config_changes)

        return [
            *sections.values(),
            reload_documents_checkbox,
            pn.indicators.LoadingSpinner(
                name="",
                value=False,
                size=30,
            ),
            apply_button,
        ]

    def _toggle_reload_documents(self, event: Any) -> None:
        """Toggle the reload documents checkbox."""
        self._reload_documents = event.new

    def _apply_config_changes(self, event: Any) -> None:
        """Apply configuration changes and reinitialize the pipeline."""
        event.obj.disabled = True
        spinner = self._config_widgets[-2]
        spinner.value = True
        try:
            embedding_section = self._config_widgets[0]
            text_splitting_section = self._config_widgets[1]
            vector_store_section = self._config_widgets[2]
            retrieval_section = self._config_widgets[3]
            llm_section = self._config_widgets[4]

            config_kwargs = {
                "pipeline_sources": {
                    k: [
                        DocumentSource.from_dict(v) if isinstance(v, dict) else
                        {
                            "pdf": PDF_SOURCE,
                            "mhtml": MHTML_SOURCE,
                            "html": HTML_SOURCE,
                            "txt": TXT_SOURCE,
                        }[v] if isinstance(v, str) and v in ["pdf", "mhtml", "html", "txt"] else v
                        for v in vv
                    ] for k, vv in embedding_section[0].value.items()
                },
                "pipeline_embedding_model": embedding_section[1].value,
                "pipeline_embedding_model_kwargs": {"device": embedding_section[2].value},
                "pipeline_chunk_size": text_splitting_section[0].value,
                "pipeline_chunk_overlap": text_splitting_section[1].value,
                "pipeline_vector_store": vector_store_section[0].value,
                "pipeline_search_type": retrieval_section[0].value,
                "pipeline_k": retrieval_section[1].value,
                "pipeline_llm_provider": llm_section[0].value,
                "pipeline_llm_model": llm_section[1].value,
                "pipeline_llm_model_kwargs": {"temperature": llm_section[2].value},
            }

            match retrieval_section[0].value:
                case "similarity_score_threshold":
                    config_kwargs["pipeline_score_threshold"] = float(retrieval_section[2].value)
                case "mmr":
                    config_kwargs["pipeline_fetch_k"] = int(retrieval_section[2].value)
                    config_kwargs["pipeline_lambda_mult"] = float(retrieval_section[3].value)

            match llm_section[0].value:
                case "huggingface":
                    config_kwargs["pipeline_llm_api_key"] = llm_section[3].value

            match vector_store_section[0].value:
                case "chroma":
                    config_kwargs["pipeline_persist_directory"] = vector_store_section[1].value
                    config_kwargs["pipeline_collection_name"] = vector_store_section[2].value
                    self._pipeline_config = ChromaRAGPipelineConfig(**config_kwargs)
                case _:
                    raise ValueError(f"Unsupported vector store: {vector_store_section[0].value}")

            self._pipeline = None
            asyncio.run(self._initialize_pipeline())

        except Exception as e:
            pn.state.notifications.error(
                f"Error applying configuration: {str(e)}\n\n{traceback.format_exc()}",
                duration=10000
            )
        else:
            pn.state.notifications.success("Configuration applied successfully!", duration=3000)
        finally:
            spinner.value = False
            event.obj.disabled = False

    async def _initialize_pipeline(self) -> None:
        """Initialize the RAG pipeline."""
        if self._pipeline is None:
            if isinstance(self._pipeline_config, ChromaRAGPipelineConfig):
                self._pipeline = ChromaRAGPipeline(config=self._pipeline_config)
            else:
                raise ValueError(f"Unsupported vector store: {self._pipeline_config.pipeline_vector_store}")

            if self._reload_documents:
                documents = await self._pipeline.load_documents()
                processed_documents = await self._pipeline.process_documents(documents)
                await self._pipeline.update_vectorstore(processed_documents)

            await self._pipeline.setup_retrieval_chain(context_format="json")

    async def _process_question(self, question: str) -> Dict[str, Any]:
        """Process a question through the RAG pipeline."""
        if self._pipeline is None:
            await self._initialize_pipeline()
        return await self._pipeline.run(question)

    async def _chat_callback(
        self,
        contents: str,
        user: str,
        instance: pn.chat.ChatInterface
    ):
        """Callback function for chat interface."""
        try:
            match user:
                case "Assistant" | "System":
                    pass

                case "User":
                    message = MessageWithContext(
                        object=await self._process_question(contents),
                        show_copy_icon=True,
                        show_edit_icon=False,
                        show_timestamp=True,
                        show_reaction_icons=False,
                    )
                    instance.send(value=message, user=None, avatar=None)
                    return

                case _:
                    raise ValueError(f"Invalid user: {user}")

        except Exception as e:
            yield pn.chat.ChatMessage(
                object=f"**Error**: {str(e)}\n\n```python\n{traceback.format_exc()}\n```",
                user="System",
                avatar="⚠️"
            )

    def _create_chat_interface(self) -> pn.chat.ChatInterface:
        """Create the chat interface component."""
        def save_chat_history(instance: pn.chat.ChatInterface, event: Any):
            """Save the chat history to a file."""
            filename = Path(f"chat_history_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.html")
            try:
                instance.save(filename)
            except Exception as e:
                pn.state.notifications.error(f"Error saving chat history: {str(e)}", duration=10000)
            else:
                pn.state.notifications.success(
                    f"Chat history saved successfully to: {filename.resolve()}!",
                    duration=3000
                )

        return ChatWithConfigurableMessages(
            message_kwargs={"show_reaction_icons": False},
            callback=self._chat_callback,
            widgets=pn.chat.ChatAreaInput(
                placeholder="Ask your question!",
                sizing_mode="stretch_width"
            ),
            sizing_mode="stretch_width",
            show_send=True,
            show_stop=True,
            show_rerun=True,
            show_undo=False,
            show_clear=True,
            show_button_name=False,
            avatar="👤",
            height=600,
            button_properties={"save": {"icon": "device-floppy", "callback": save_chat_history}}
        )

    def _create_layout(self) -> pn.template.BootstrapTemplate:
        """Create the main layout."""
        return pn.template.BootstrapTemplate(
            title="RAG Agent Chat Interface",
            main=[
                pn.layout.Card(
                    self._chat_interface,
                    title="Chat",
                    collapsible=False,
                    collapsed=False,
                ),
                pn.layout.Card(
                    self._terminal,
                    title="Terminal",
                    collapsible=True,
                    collapsed=True,
                ),
            ],
            sidebar=[
                pn.layout.Accordion(
                    *self._config_widgets[:-3],
                    active=[0],
                    toggle=True,
                ),
                self._config_widgets[-3],
                pn.layout.Row(
                    pn.layout.HSpacer(),
                    *self._config_widgets[-2:],
                ),
            ],
            sidebar_width=350,
            theme=pn.template.DefaultTheme,
            collapsed_sidebar=True,
        )

    def serve(self, port: int = 8501) -> None:
        """Serve the application."""
        pn.state.onload(self._initialize_pipeline)
        pn.serve(
            self._layout,
            port=port,
            show=False,
            title="RAG Chat Interface",
            allow_websocket_origin=["*"],
        )


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="RAG Chat Interface")
    parser.add_argument("--port", type=int, default=8501, help="Port to serve the application on")
    parser.add_argument(
        "--database-engine",
        type=str,
        default="chroma",
        help="Database engine to use",
        choices=["chroma"]
    )
    args = parser.parse_args()

    match args.database_engine:
        case "chroma":
            config = ChromaRAGPipelineConfig()
        case _:
            raise ValueError(f"Unsupported database engine: {args.db_engine}")

    chat = RAGChatInterface(config=config)
    chat.serve(port=args.port)


if __name__ == "__main__":
    main()
