import ast
import asyncio
import json
import logging
from queue import Full, Queue
from threading import Event, Thread

import socketio
from aiortc import (
    MediaStreamTrack,
    RTCConfiguration,
    RTCIceServer,
    RTCPeerConnection,
    RTCSessionDescription,
)
from aiortc.contrib.media import MediaBlackhole, MediaRelay
from aiortc.mediastreams import MediaStreamError
from aiortc.rtcrtpreceiver import RemoteStreamTrack

from gspeerconnection.gsmediarecorder import GSMediaRecorder

# Module-level MediaRelay instance.
# NOTE(review): not referenced anywhere in this module's visible code; confirm
# whether other modules import it before removing.
relay = MediaRelay()


class GSPeerConnectionWatcher:
    """Receive-only WebRTC viewer for one broadcaster, signalled over Socket.IO.

    Use :meth:`create`. It registers the Socket.IO event handlers, connects to
    the signalling server, and then waits on the connection.

    For each ``broadcaster`` event it:

    * creates a new ``RTCPeerConnection`` with one ``recvonly`` video
      transceiver;
    * emits the offer as a ``watcher`` event.

    The following ``answer`` event completes the negotiation.
    """

    @classmethod
    async def create(cls, gsdbs, target, onframe=None, onmessage=None, ontrack=None, ontrackended=None):
        """Build a watcher, connect it to the signalling server and wait on it.

        Args:
            gsdbs: Project client object. Its ``credentials`` mapping supplies
                the STUN/TURN and signalling settings. Its ``cookiejar``
                supplies the session cookies sent in the connect URL.
            target: Identifier of the broadcaster to watch. It is sent in the
                connect URL and in every ``watcher`` offer.
            onframe: Optional callable ``onframe(gsdbs, target, frame)``. It is
                invoked on a worker thread (via ``GSMediaConsumer``) for every
                received video frame.
            onmessage: Stored on the instance; not used by this class.
            ontrack: Optional coroutine function ``ontrack(gsdbs, track, target)``.
                It is awaited when a video track arrives.
            ontrackended: Optional coroutine function ``ontrackended(gsdbs, target)``.
                It is awaited when the server emits ``disconnectBroadcaster``.

        Returns:
            None. The coroutine finishes only after ``self.sio.wait()``
            returns.
        """
        self = cls()
        self.rtconfiList = []
        self.sio = socketio.AsyncClient()
        self.gsdbs = gsdbs
        self.onframe = onframe
        self.target = target
        self.onmessage = onmessage
        self.ontrack = ontrack
        self.ontrackended = ontrackended
        self.logger = logging.getLogger(__name__)
        # Most recently created peer connection. It stays None until the first
        # "broadcaster" event, so `answer` can detect an out-of-order message.
        self.peerConnections = None

        credentials = self.gsdbs.credentials
        if credentials["stunenable"]:
            self.rtconfiList.append(RTCIceServer(credentials["stunserver"]))
        if credentials["turnenable"]:
            self.rtconfiList.append(RTCIceServer(credentials["turnserver"],
                                                 credentials["turnuser"],
                                                 credentials["turnpw"]))

        # Socket.IO dispatches on the function name, so the handler names below
        # are the event names and must not change. Their arguments arrive
        # positionally, in the order the server sends them.

        @self.sio.event
        async def connect():
            self.logger.info('connection established')

        @self.sio.event
        async def joined(peer_id):
            self.logger.info("joined")

        @self.sio.event
        async def broadcaster(peer_id):
            if self.rtconfiList:
                pc = RTCPeerConnection(configuration=RTCConfiguration(self.rtconfiList))
            else:
                pc = RTCPeerConnection()
            # NOTE(review): a repeated "broadcaster" event replaces the previous
            # connection here without closing it.
            self.peerConnections = pc

            pc.addTransceiver('video', direction='recvonly')

            # The handlers below close over the local `pc`, not
            # self.peerConnections. Each one therefore acts on its own
            # connection even after a later event has replaced the attribute.
            @pc.on("iceconnectionstatechange")
            async def on_iceconnectionstatechange():
                if pc.iceConnectionState == "failed":
                    await pc.close()

            @pc.on("track")
            async def on_track(track):
                if track.kind != "video":
                    return
                if self.ontrack is not None:
                    self.logger.info("on track received.Recording started")
                    await self.ontrack(self.gsdbs, track, self.target)
                if self.onframe is not None:
                    self.logger.info("on track received. onframe started")
                    consumer = GSMediaConsumer(self.gsdbs, self.target, self.onframe)
                    consumer.addTrack(track)
                    await consumer.start()

            await pc.setLocalDescription(await pc.createOffer())
            await self.sio.emit("watcher",
                                {"target": self.target,
                                 "sdp": {"type": pc.localDescription.type,
                                         "sdp": pc.localDescription.sdp}})

        @self.sio.event
        async def answer(peer_id, description):
            if self.peerConnections is None:
                self.logger.warning("answer received before any broadcaster offer; ignored")
                return
            if not isinstance(description, dict):
                # Non-dict payloads are parsed as a Python literal.
                # ast.literal_eval (unlike eval) cannot execute code sent by
                # the server.
                description = ast.literal_eval(description)
            await self.peerConnections.setRemoteDescription(
                RTCSessionDescription(sdp=description["sdp"], type=description["type"]))

        @self.sio.event
        async def disconnectBroadcaster(peer_id, target):
            # ontrackended is optional (defaults to None); calling None would
            # raise TypeError inside the event handler.
            if self.ontrackended is not None:
                await self.ontrackended(self.gsdbs, target)

        if "localhost" in credentials["signalserver"]:
            connectURL = f'{credentials["signalserver"]}:{str(credentials["signalport"])}'
        else:
            connectURL = credentials["signalserver"]

        gssession = (f'{self.gsdbs.cookiejar.get("session")}.'
                     f'{self.gsdbs.cookiejar.get("signature")}{credentials["cnode"]}')
        await self.sio.connect(f'{connectURL}?gssession={gssession}&target={self.target}')
        await self.sio.wait()


class FrameBufferThread(Thread):
    """Worker thread that passes buffered frames to a synchronous callback.

    :meth:`write` pushes frames into a small bounded queue without blocking
    and drops them when the queue is full. :meth:`run` pops them and calls
    ``onframe(gsdbs, target, frame)`` on this thread, so a slow callback
    cannot stall the caller of :meth:`write`. Call :meth:`stop` to end the
    thread.
    """

    def __init__(self, gsdbs, target, onframe, queue_size=3):
        # daemon=True: run() only returns after stop(). A non-daemon thread
        # that was never stopped would keep the interpreter alive at exit.
        super().__init__(daemon=True)
        self.gsdbs = gsdbs
        self.logger = logging.getLogger(__name__)
        self.logger.debug("Starting Buffer")
        self.target = target
        self.onframe = onframe
        self.nal_queue = Queue(queue_size)
        # Set by stop(); checked by run() before each get().
        self._stop_requested = Event()

    def write(self, frame):
        """Enqueue *frame* without blocking; drop it if the buffer is full."""
        try:
            self.nal_queue.put_nowait(frame)
        except Full:
            # Drop rather than block: write() is called from the asyncio
            # event loop (GSMediaConsumer.stream_consume).
            pass

    def stop(self):
        """Ask the thread to exit once the frame it is handling is done."""
        self._stop_requested.set()
        try:
            # Wake run() if it is blocked on an empty queue.
            self.nal_queue.put_nowait(None)
        except Full:
            # The queue is non-empty, so run() is not blocked in get() and will
            # see the stop flag on its next iteration.
            pass

    def run(self):
        """Deliver queued frames to ``onframe`` until :meth:`stop` is called."""
        while not self._stop_requested.is_set():
            buf = self.nal_queue.get()
            if buf is None:
                # Wake-up sentinel from stop(), not a frame.
                continue
            try:
                self.onframe(self.gsdbs, self.target, buf)
            except Exception:
                # Keep serving later frames; one failing callback used to kill
                # the thread and silently stop all frame delivery.
                self.logger.exception("onframe callback failed; frame dropped")


class GSMediaConsumer:
    """Read frames from remote tracks and forward them to ``onframe``.

    Each added track is read by its own asyncio task. Frames go to a
    :class:`FrameBufferThread`, which calls ``onframe(gsdbs, target, frame)``
    on a worker thread so the callback never blocks the event loop.
    """

    def __init__(self, gsdbs, target, onframe):
        # Maps each added track to its reading task (None until start()).
        self.__tracks = {}
        self.gsdbs = gsdbs
        self.target = target
        self.onframe = onframe
        # NOTE(review): unused by this class; frames are buffered in
        # broadcastthread.nal_queue instead. Kept only in case outside code
        # reads this attribute.
        self.nal_queue = Queue(3)
        self.broadcastthread = FrameBufferThread(self.gsdbs, self.target, self.onframe)
        self.broadcastthread.start()

    def addTrack(self, track):
        """Register *track* for reading; adding the same track twice is a no-op."""
        if track not in self.__tracks:
            self.__tracks[track] = None

    async def start(self):
        """Start a reading task for every registered track that has none yet."""
        pending = [track for track, task in self.__tracks.items() if task is None]
        for track in pending:
            self.__tracks[track] = asyncio.ensure_future(
                self.stream_consume(track, self.gsdbs, self.target, self.onframe))

    async def stream_consume(self, track, gsdbs, target, onframe):
        """Copy frames from *track* into the buffer thread until the track ends.

        ``gsdbs``, ``target`` and ``onframe`` are accepted only for interface
        compatibility; the buffer thread already holds them.
        """
        while True:
            try:
                frame = await track.recv()
            except MediaStreamError:
                # aiortc's recv() raises MediaStreamError once the track ends.
                return
            self.broadcastthread.write(frame)

    async def stop(self):
        """Cancel every reading task and shut down the buffer thread."""
        for task in self.__tracks.values():
            if task is not None:
                task.cancel()
        self.__tracks = {}
        self.broadcastthread.stop()
