from __future__ import annotations

import structlog
from anyio import create_task_group
from fps import Module

from jupyverse_api.app import App
from jupyverse_api.auth import Auth
from jupyverse_api.frontend import FrontendConfig
from jupyverse_api.kernels import Kernels, KernelsConfig
from jupyverse_api.main import Lifespan
from jupyverse_api.yjs import Yjs

from .routes import _Kernels

log = structlog.get_logger()


class KernelsModule(Module):
    """fps module that wires up the jupyverse kernels service.

    It gathers the shared services it depends on (``App``, ``Auth``,
    ``FrontendConfig``, ``Lifespan`` and, optionally, ``Yjs``). It builds a
    ``_Kernels`` instance from them and publishes that instance under the
    ``Kernels`` API type.
    """

    def __init__(self, name: str, **kwargs):
        """Create the module and its configuration.

        Args:
            name: Module name, passed to ``fps.Module``.
            **kwargs: Forwarded verbatim to ``KernelsConfig``.
        """
        super().__init__(name)
        self.kernels_config = KernelsConfig(**kwargs)

    async def prepare(self) -> None:
        """Publish the config, build the kernels service and start it.

        The anyio task group does not exit until ``self.kernels.start``
        returns, so this coroutine stays alive while the service runs.
        ``self.done()`` is called right after the start task is scheduled.
        It presumably tells fps that preparation is complete; confirm against
        the fps documentation.
        """
        # Publish the config first so other modules can request it.
        self.put(self.kernels_config, KernelsConfig)

        app = await self.get(App)
        auth = await self.get(Auth)  # type: ignore[type-abstract]
        frontend_config = await self.get(FrontendConfig)
        lifespan = await self.get(Lifespan)
        # Yjs is only requested when the config requires it; otherwise
        # ``None`` is passed to ``_Kernels``.
        yjs = (
            await self.get(Yjs)  # type: ignore[type-abstract]
            if self.kernels_config.require_yjs
            else None
        )

        self.kernels = _Kernels(app, self.kernels_config, auth, frontend_config, yjs, lifespan)
        self.put(self.kernels, Kernels)

        async with create_task_group() as tg:
            tg.start_soon(self.kernels.start)
            self.done()

    async def stop(self) -> None:
        """Stop the kernels service, if ``prepare`` got far enough to build it.

        ``self.kernels`` is only assigned partway through ``prepare``. If
        ``prepare`` failed before that point, or never ran, there is nothing
        to stop. In that case this returns instead of raising
        ``AttributeError``, which could otherwise mask the original error
        during teardown.
        """
        kernels = getattr(self, "kernels", None)
        if kernels is None:
            return
        await kernels.stop()
