import grpc

# classes generated by grpc
from .generated import eval_server_pb2
from .generated import eval_server_pb2_grpc

# other imports
import numpy as np 
import time
from time import perf_counter
from datetime import datetime

# multiprocessing 
import multiprocessing as mp
from multiprocessing import shared_memory
from threading import Thread


def get_grpc_result(bboxes, scores, labels):
    """Pack parallel detection sequences into an ``eval_server_pb2.Result``.

    Args:
        bboxes: sequence of boxes; elements 0..3 of each box are copied into
            the ``x1``, ``y1``, ``x2``, ``y2`` fields of a ``Bbox`` message.
        scores: one score per box, appended to ``bbox_scores``.
        labels: one label per box, appended to ``labels``.

    Returns:
        A ``Result`` message whose repeated fields are aligned by index.

    Raises:
        ValueError: if the three sequences do not have the same length.
    """
    count = len(bboxes)
    # Validate before touching any protobuf type so bad input fails fast.
    if not (len(scores) == count and len(labels) == count):
        raise ValueError('number of bboxes, labels, bbox_scores must be same')

    message = eval_server_pb2.Result()
    for box, score, label in zip(bboxes, scores, labels):
        proto_box = eval_server_pb2.Bbox()
        proto_box.x1, proto_box.y1 = box[0], box[1]
        proto_box.x2, proto_box.y2 = box[2], box[3]

        message.bboxes.append(proto_box)
        message.bbox_scores.append(score)
        message.labels.append(label)
    return message

# Receive frame indices streamed by the image server and record each frame's shared-memory pointers in a dict keyed by fid.
def receive_stream(seq, latest_fidx, fid_ptr_dict, is_stream_ready, stream_start_time, config, verbose=False):
    """Consume the image service's frame stream for ``seq`` and publish it.

    Intended to run in its own process; all outputs go through the shared
    objects passed in.

    Args:
        seq: sequence name, sent as the ``value`` of an ``eval_server_pb2.String``.
        latest_fidx: shared value receiving the most recent frame id, or -1
            once the server sends an end marker.
        fid_ptr_dict: shared mapping fid -> (start_ptr, end_ptr) into the
            frame shared-memory buffer.
        is_stream_ready: event that is set whenever ``latest_fidx`` holds a
            value consumers may read.
        stream_start_time: shared value set to ``perf_counter()`` when the
            first response arrives.
        config: mapping providing ``loopback_ip`` and ``image_service_port``.
        verbose: print progress messages.
    """
    if verbose:
        print("EvalClient (", datetime.now(), "): ", "Requesting stream for sequence ", seq)
    address = config['loopback_ip'] + ":" + str(config['image_service_port'])
    # The context manager closes the channel even if the stream raises.
    with grpc.insecure_channel(address) as channel:
        stub = eval_server_pb2_grpc.ImageServiceStub(channel)
        stream_request = eval_server_pb2.String(value=seq)

        for i, response in enumerate(stub.GetImageStream(stream_request)):
            if i == 0:
                stream_start_time.value = perf_counter()
                if verbose:
                    print("EvalClient (", datetime.now(), "): ", "Receiving stream for sequence ", seq)
            if response.end_marker:
                latest_fidx.value = -1
                # Wake any waiter so it observes -1; otherwise a stream that
                # ends before its first frame leaves consumers blocked forever.
                is_stream_ready.set()
                break
            is_stream_ready.clear()
            # Record the pointers BEFORE advancing latest_fidx. A consumer that
            # already passed is_stream_ready.wait() may read latest_fidx at any
            # moment and then immediately index fid_ptr_dict with it.
            fid_ptr_dict[response.fid] = (response.start_ptr, response.end_ptr)
            latest_fidx.value = response.fid
            if response.fid >= 0:
                is_stream_ready.set()

class EvalClient:
    """Client for the evaluation server's image and result gRPC services.

    Frames are read from a shared-memory segment whose name is obtained from
    the image service (``GetShm``). Frame indices and pointers are received by
    ``receive_stream`` in a child process. Results are streamed to the result
    service (``PutResultStream``) from a background thread.
    """

    def __init__(self, config, state=None, verbose=False):
        """Connect to the services and attach to the frame shared memory.

        Args:
            config: mapping providing ``loopback_ip``, ``image_service_port``
                and ``result_service_port``.
            state: optional tuple as returned by :meth:`get_state`
                (``latest_fidx``, ``is_stream_ready``, ``fid_ptr_dict``). A
                fourth element, if present, is used as ``stream_start_time``.
            verbose: forwarded to the stream receiver for progress messages.
        """
        # Frames are laid out as img_height x img_width x 3 uint8 in the shm buffer.
        self.img_width, self.img_height = 1920, 1200

        if state is None:
            # set_start_method raises RuntimeError if called again once a
            # method is fixed; skip it when 'spawn' is already in effect
            # (e.g. a second EvalClient in the same process).
            if mp.get_start_method(allow_none=True) != 'spawn':
                mp.set_start_method('spawn')
            self.latest_fidx = mp.Value('i', -1, lock=True)
            self.is_stream_ready = mp.Event()
            # Keep an explicit handle on the manager that serves fid_ptr_dict.
            self._manager = mp.Manager()
            self.fid_ptr_dict = self._manager.dict()
            self.stream_start_time = mp.Value('d', 0.0, lock=True)
        else:
            self.latest_fidx = state[0]
            self.is_stream_ready = state[1]
            self.fid_ptr_dict = state[2]
            # get_state() does not export this value. Without it, request_stream()
            # and get_stream_start_time() would fail with AttributeError.
            if len(state) > 3:
                self.stream_start_time = state[3]
            else:
                self.stream_start_time = mp.Value('d', 0.0, lock=True)

        self.verbose = verbose
        self.config = config

        # One-shot image-service call to learn the shared-memory segment name.
        self.channel = grpc.insecure_channel(config['loopback_ip'] + ":" + str(config['image_service_port']))
        self.stub = eval_server_pb2_grpc.ImageServiceStub(self.channel)
        response = self.stub.GetShm(eval_server_pb2.Empty())
        self.existing_shm = shared_memory.SharedMemory(name=response.value)
        self.channel.close()

        # Long-lived result-service channel, closed in close().
        self.result_channel = grpc.insecure_channel(config['loopback_ip'] + ":" + str(config['result_service_port']))
        self.result_stub = eval_server_pb2_grpc.ResultServiceStub(self.result_channel)

        self.is_stream_ready.clear()
        self.stream_process = None

        # Hand-off slot between send_result_async() and result_stream_iterator().
        self.latest_grpc_result = None
        self.grpc_result_ready = mp.Event()  # created unset
        self.sequence_ended = False
        self.res_thread = None

    def get_state(self):
        """Return the shared objects another client can reuse via ``state=``."""
        return (self.latest_fidx, self.is_stream_ready, self.fid_ptr_dict)

    def close(self, results_file='results.json'):
        """Ask the result service to generate ``results_file``, then close the channel.

        The RPC must be issued before the channel is closed; an RPC on a
        closed channel fails.
        """
        try:
            self.result_stub.GenResults(eval_server_pb2.String(value=results_file))
        finally:
            self.result_channel.close()

    def result_stream_iterator(self):
        """Yield each newly posted result until ``stop_stream`` ends the sequence.

        The event is cleared BEFORE ``sequence_ended`` is checked. If it were
        cleared after the yield, a stop signal posted while the previous result
        was in flight would be erased and the generator (and so ``stop_stream``)
        would block forever. The identity check prevents yielding the same
        message twice when a result arrives between the clear and the read.
        """
        last_sent = None
        while True:
            self.grpc_result_ready.wait()
            self.grpc_result_ready.clear()
            if self.sequence_ended:
                break
            result = self.latest_grpc_result
            if result is last_sent:
                continue
            last_sent = result
            yield result

    def send_result_stream(self):
        """Stream results to the result service; blocks until the iterator ends."""
        self.result_stub.PutResultStream(self.result_stream_iterator())

    def stop_stream(self):
        """Finish the current sequence and wait for the receiver and sender to exit."""
        if self.stream_process is not None:
            self.stream_process.join()
        self.result_stub.FinishSequence(eval_server_pb2.Empty())
        self.is_stream_ready.clear()

        self.stream_process = None
        # Set the flag before signalling so the iterator sees it when it wakes.
        self.sequence_ended = True
        self.grpc_result_ready.set()
        if self.res_thread is not None:
            self.res_thread.join()
            self.res_thread = None

    def request_stream(self, seq):
        """Start the result-sender thread and the frame-receiver process for ``seq``."""
        self.sequence_ended = False
        self.grpc_result_ready.clear()
        self.res_thread = Thread(target=self.send_result_stream)
        self.res_thread.start()

        # Frame index receiver runs as a separate process.
        self.stream_process = mp.Process(
            target=receive_stream,
            args=(seq, self.latest_fidx, self.fid_ptr_dict, self.is_stream_ready,
                  self.stream_start_time, self.config, self.verbose),
        )
        self.stream_process.start()

    def get_latest_fidx(self):
        """Block until the stream is ready, then return the latest frame id (-1 at end)."""
        self.is_stream_ready.wait()
        return self.latest_fidx.value

    def get_frame(self, fid=None, ptr=False):
        """Return ``(fid, frame)`` for ``fid``, or for the latest frame if ``fid`` is None.

        Args:
            fid: non-negative frame id, or None to wait for the latest frame.
            ptr: if True, return the frame's slot index in the shared buffer
                (``start_ptr`` divided by the frame size) instead of an array.

        Returns:
            ``(fid, ndarray)`` where the array has shape
            (img_height, img_width, 3), dtype uint8, and is a view on the shared
            buffer. ``(fid, slot)`` when ``ptr`` is True. ``(None, None)`` when
            ``fid`` is None and the stream has ended.

        Raises:
            TypeError: if ``fid`` is negative (type kept for existing callers).
            KeyError: if an explicit ``fid`` has not been received yet.
        """
        if fid is not None and fid < 0:
            raise TypeError("fid must be non-negative")
        if fid is None:
            fid = self.get_latest_fidx()
            if fid == -1:
                return None, None
            start_ptr, end_ptr = self.fid_ptr_dict[fid]
        else:
            # Single lookup: each access to a manager dict is a round-trip.
            try:
                start_ptr, end_ptr = self.fid_ptr_dict[fid]
            except KeyError:
                raise KeyError("frame not available yet") from None
        frame_nbytes = self.img_height * self.img_width * 3
        if ptr:
            return fid, int(start_ptr // frame_nbytes)
        return fid, np.ndarray((self.img_height, self.img_width, 3), dtype=np.uint8,
                               buffer=self.existing_shm.buf[start_ptr:end_ptr])

    def send_result_async(self, bboxes, bbox_scores, labels):
        """Build a timestamped result message and post it for the sender thread."""
        timestamp = perf_counter()
        grpc_result = get_grpc_result(bboxes, bbox_scores, labels)
        grpc_result.timestamp = timestamp
        self.latest_grpc_result = grpc_result
        self.grpc_result_ready.set()

    def send_result_to_server(self, bboxes, bbox_scores, labels):
        """Post a result without blocking the caller (builds it on a short-lived thread)."""
        worker = Thread(target=self.send_result_async, args=(bboxes, bbox_scores, labels))
        worker.start()

    def get_frame_buf(self):
        """Return the attached ``SharedMemory`` holding the frames."""
        return self.existing_shm

    def get_stream_start_time(self):
        """Block until the stream is ready, then return its ``perf_counter`` start time."""
        self.is_stream_ready.wait()
        return self.stream_start_time.value