# auto generated by update_py.py

import os
import time
import threading
import signal
import zmq
import json
from collections import deque

import tlclient.linker.message_comm as message
from tlclient.linker.frame import Frame, FrameHeaderStatus, MemBuffer
from tlclient.linker.timer import Timer
from tlclient.linker.constant import MsgType, CommType, FistType, HeartBeatStatus, ANY_FIST_NAME, DEFAULT_REQ_ID
from tlclient.linker.utility import bytify
from tlclient.linker.logger import Logger


class Fist:
    """A node ("fist") that talks to a master and to other fists over ZeroMQ.

    Lifecycle implied by the code below: set the master REQ endpoint (``addr``
    or ``set_master_addr``), call ``create_fist`` to register with the master,
    then bind this fist's own endpoints with the ``set_*`` methods and connect
    to other fists with the ``reg_*`` methods (both ask the master for the
    address over the REQ socket created by ``create_fist``).  ``start`` then
    launches one daemon thread per inbound channel, and subclasses override
    the ``on_*`` callbacks to handle traffic.

    Naming note: the "push"/"pull" channel is implemented with PUB/SUB
    sockets (``reg_push`` opens a PUB socket, ``set_pull`` binds a SUB socket).
    """

    # Capacity handed to MemBuffer for batched sends (see batch_pub_frame).
    ZMQ_BATCH_PUB_BUFFER_SIZE = 102400
    # Poll timeout (ms) of the receive threads; also how often they re-check
    # is_stopped().
    POLL_WAIT_TIME = 1000
    # How long (ms) a request waits for its reply before timing out.
    REQ_WAIT_TIME = 2000

    def __init__(self, fist_name, fist_type, env_name, addr=None):
        """Create an unregistered fist.

        Args:
            fist_name: name of this fist; also the logger name and the routing
                id of the master REQ socket.
            fist_type: type reported to the master (presumably a FistType
                value, since it is compared against FistType members).
            env_name: environment name sent to the master on creation.
            addr: optional master REQ endpoint; can also be set later via
                set_master_addr().
        """
        self.logger = Logger.get_logger(fist_name)
        self.fist_name = fist_name
        self.fist_type = fist_type
        self.env_name = env_name
        self._context = zmq.Context()
        # internal
        self.source_id = None  # assigned by the master in create_fist()
        self.pub_sock = None
        self.pull_sock = None
        self.rep_sock = None
        self.sub_master_sock = None
        self.req_master_sock = None
        self.master_addr = None
        self.influx_db_client = None
        self.sub_poller = zmq.Poller()  # all SUB sockets created by reg_sub()
        # Used for every SNDHWM/RCVHWM below (ZeroMQ treats 0 as "no limit").
        self.buffer_size = 0
        self.push_socks = {}  # other fist name -> PUB socket (see reg_push)
        self.req_socks = {}  # other fist name -> REQ socket (see reg_req)
        self.fist_req_addrs = {}  # other fist name -> REP address, for resets
        # Guards cmd_pending_rids and serializes command requests to the master.
        self.cmd_lock = threading.Lock()
        self.cmd_pending_rids = set()
        self.received_signal = None  # non-None once stop() has been called
        self.heart_beat_msg = message.MsgHeartBeat()
        self.heart_beat_msg.fist_type = self.fist_type
        self.heart_beat_msg.fist_name = bytify(self.fist_name)
        self.heart_beat_msg.desc_name = bytify('')
        self.heart_beat_msg.hb_status = HeartBeatStatus.NOT_AVAILABLE
        self.mem_buffer = None  # lazily created buffer used when batching
        # set to True by each receive thread once it is running
        self.sub_ready = False
        self.pull_ready = False
        self.rep_ready = False
        self.sub_master_ready = False
        if addr:
            self.set_master_addr(addr)

    def is_stopped(self):
        """Return True once stop() has been called (directly or via a signal)."""
        return self.received_signal is not None

    def _recv_req(self, socket, timeout):
        """Receive one message from ``socket``, waiting at most ``timeout`` ms.

        Returns the raw message, or None (after logging an error) on timeout.
        """
        poll = zmq.Poller()
        poll.register(socket, zmq.POLLIN)
        sockets = dict(poll.poll(timeout))
        if socket in sockets:
            return socket.recv()
        self.logger.error('[recv] failed to receive data')
        return None

    def set_master_addr(self, addr):
        """Set the master REQ endpoint used by reg_req_master()."""
        self.master_addr = addr

    def req_master(self, frame):
        """Send ``frame`` to the master and return its reply as a Frame.

        Raises:
            Exception: if no reply arrives within REQ_WAIT_TIME ms.
        """
        self.req_master_sock.send(frame.buf)
        ret = self._recv_req(self.req_master_sock, Fist.REQ_WAIT_TIME)
        if ret is None:
            raise Exception("Timeout in req_master!")
        return Frame(ret)

    def create_fist(self):
        """Register this fist with the master and subscribe to its PUB feed.

        Opens the master REQ socket, sends REQ_FIST_CREATE (name, type, env,
        pid), stores the returned source id and then calls reg_sub_master().

        Raises:
            Exception: if the master rejects the fist (source <= 0) or does
                not answer in time.
        """
        req = message.ReqFistCreate()
        req.fist_name = bytify(self.fist_name)
        req.fist_type = self.fist_type
        req.env_name = bytify(self.env_name)
        req.pid = os.getpid()
        self.reg_req_master()
        f = Frame()
        f.set_msg_type(MsgType.REQ_FIST_CREATE)
        f.set_nano(Timer.nano())
        f.set_data(req)
        ret = self.req_master(f)
        ret_obj = ret.get_obj(message.RspFistCreate)
        if ret_obj.source <= 0:
            raise Exception("failed to create fist: {}".format(ret_obj.err_msg))
        self.source_id = ret_obj.source
        self.logger.debug('[create_fist] (source_id){}'.format(self.source_id))
        self.reg_sub_master()

    def set_io_thread_num(self, io_threads=1):
        """Set the number of ZeroMQ I/O threads of this fist's context.

        Only effective before the first socket is created on the context,
        i.e. before create_fist() / set_*() / reg_*().
        """
        # IO_THREADS is a *context* option.  Context.set() applies it, while
        # Context.setsockopt() merely records a default *socket* option.
        self._context.set(zmq.IO_THREADS, io_threads)

    def set_buffer_size(self, buffer_size):
        """Set the high-water mark used for sockets created afterwards."""
        assert isinstance(buffer_size, int) and buffer_size >= 0, 'buffer size should be a non-negative integer'
        self.buffer_size = buffer_size

    def set_pub(self):
        """Bind this fist's PUB socket at the address granted by the master."""
        addr = self.get_bind_addr(CommType.Zmq_PUB)
        self.pub_sock = self._context.socket(zmq.PUB)
        self.pub_sock.setsockopt(zmq.SNDHWM, self.buffer_size)
        self.pub_sock.bind(addr)
        self.logger.debug('[set_pub] (addr){}'.format(addr))
        return True

    def set_pull(self):
        """Bind this fist's "pull" endpoint (a SUB socket subscribed to all)."""
        addr = self.get_bind_addr(CommType.Zmq_PULL)
        self.pull_sock = self._context.socket(zmq.SUB)
        self.pull_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.pull_sock.setsockopt(zmq.RCVHWM, self.buffer_size)
        self.pull_sock.bind(addr)
        self.logger.debug('[set_pull] (addr){}'.format(addr))
        return True

    def set_rep(self):
        """Bind this fist's REP socket; requests are served by run_rep()."""
        addr = self.get_bind_addr(CommType.Zmq_REP)
        self.rep_sock = self._context.socket(zmq.REP)
        self.rep_sock.setsockopt(zmq.SNDHWM, self.buffer_size)
        self.rep_sock.setsockopt(zmq.RCVHWM, self.buffer_size)
        self.rep_sock.bind(addr)
        self.logger.debug('[set_rep] (addr){}'.format(addr))
        return True

    def reg_sub(self, other_fist_name):
        """Subscribe to everything ``other_fist_name`` publishes.

        The socket is serviced by run_sub().  Returns True on success, False
        if the master did not hand out a connectable address.
        """
        addr = self.get_connect_addr(other_fist_name, CommType.Zmq_PUB)
        if addr is None:
            # get_connect_addr already logged the reason; never connect to it.
            return False
        sub_sock = self._context.socket(zmq.SUB)
        sub_sock.setsockopt(zmq.SUBSCRIBE, b'')
        sub_sock.setsockopt(zmq.RCVHWM, self.buffer_size)
        sub_sock.connect(addr)
        self.sub_poller.register(sub_sock, zmq.POLLIN)
        self.logger.debug('[reg_sub] (other_fist_name){} (addr){}'.format(other_fist_name, addr))
        return True

    def reg_sub_master(self):
        """Subscribe to the master's PUB feed (handled by run_sub_master()).

        Returns True on success, False if the master did not hand out a
        connectable address.
        """
        addr = self.get_connect_addr('master', CommType.Zmq_PUB)
        if addr is None:
            return False
        self.sub_master_sock = self._context.socket(zmq.SUB)
        self.sub_master_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.sub_master_sock.setsockopt(zmq.RCVHWM, self.buffer_size)
        self.sub_master_sock.connect(addr)
        self.logger.debug('[reg_sub_master] (addr){}'.format(addr))
        return True

    def reg_req_master(self):
        """Open the REQ socket to ``master_addr`` (routing id = fist name)."""
        self.req_master_sock = self._context.socket(zmq.REQ)
        self.req_master_sock.setsockopt(zmq.ROUTING_ID, bytify(self.fist_name))
        self.req_master_sock.setsockopt(zmq.SNDHWM, self.buffer_size)
        self.req_master_sock.setsockopt(zmq.RCVHWM, self.buffer_size)
        self.req_master_sock.connect(self.master_addr)
        self.logger.debug('[reg_req_master] (master_addr){}'.format(self.master_addr))
        return True

    def reg_push(self, other_fist_name):
        """Connect a PUB socket to ``other_fist_name``'s "pull" endpoint.

        Returns True on success, False if the master did not hand out a
        connectable address.
        """
        addr = self.get_connect_addr(other_fist_name, CommType.Zmq_PULL)
        if addr is None:
            return False
        push_sock = self._context.socket(zmq.PUB)
        push_sock.setsockopt(zmq.SNDHWM, self.buffer_size)
        push_sock.connect(addr)
        self.push_socks[other_fist_name] = push_sock
        self.logger.debug('[reg_push] (other_fist_name){} (addr){}'.format(other_fist_name, addr))
        return True

    def _new_req_sock(self, addr):
        """Create a REQ socket with this fist's HWM settings, connected to addr."""
        req_sock = self._context.socket(zmq.REQ)
        req_sock.setsockopt(zmq.SNDHWM, self.buffer_size)
        req_sock.setsockopt(zmq.RCVHWM, self.buffer_size)
        req_sock.connect(addr)
        return req_sock

    def reg_req(self, other_fist_name):
        """Connect a REQ socket to ``other_fist_name``'s REP endpoint.

        Returns True on success, False if the master did not hand out a
        connectable address.
        """
        addr = self.get_connect_addr(other_fist_name, CommType.Zmq_REP)
        if addr is None:
            return False
        # Remembered so req() can rebuild the socket after a failed exchange.
        self.fist_req_addrs[other_fist_name] = addr
        self.req_socks[other_fist_name] = self._new_req_sock(addr)
        self.logger.debug('[reg_req] (other_fist_name){} (addr){}'.format(other_fist_name, addr))
        return True

    def pub_f(self, frame):
        """Publish an already-built frame on this fist's PUB socket."""
        self.pub_sock.send(frame.buf)

    def batch_pub_frame(self, sock, frame_deque):
        """Send as many queued frames as fit in one message on ``sock``.

        Frames are popped from the left of ``frame_deque`` and packed into the
        shared MemBuffer until the next one would not fit.  If even the first
        frame does not fit, it is sent alone.  Returns the number of frames
        consumed (0 if the deque was empty).
        """
        if self.mem_buffer is None:
            self.mem_buffer = MemBuffer(self.ZMQ_BATCH_PUB_BUFFER_SIZE)
        if len(frame_deque) == 0:
            return 0
        frame_num = 0
        idx = 0
        while len(frame_deque) > 0:
            if idx + frame_deque[0].get_length() >= self.mem_buffer.get_length():
                break
            f = frame_deque.popleft()
            idx = self.mem_buffer.append(idx, f)
            frame_num += 1
        if frame_num == 0:
            f = frame_deque.popleft()
            self.logger.warning('[batch_zmq_pub] a big frame is sent (size){}'.format(f.get_length()))
            sock.send(f.buf)
            return 1
        self.mem_buffer.finalize()
        sock.send(self.mem_buffer._buffer[:idx])
        return frame_num

    def batch_pub(self, frame_deque):
        """Batch-send queued frames on this fist's PUB socket."""
        return self.batch_pub_frame(self.pub_sock, frame_deque)

    def batch_push(self, fist_name, frame_deque):
        """Batch-send queued frames to ``fist_name``'s "pull" endpoint."""
        ps_sock = self.push_socks[fist_name]
        return self.batch_pub_frame(ps_sock, frame_deque)

    def _make_frame(self, obj, msg_type, req_id, err_id):
        """Build a NORMAL frame carrying ``obj``, stamped with time and source."""
        f = Frame()
        f.set_status(FrameHeaderStatus.NORMAL)
        f.set_msg_type(msg_type)
        f.set_req_id(req_id)
        f.set_err_id(err_id)
        f.set_nano(Timer.nano())
        f.set_source(self.get_source_id())
        f.set_data(obj)
        return f

    def pub(self, obj, msg_type, req_id, err_id=0):
        """Publish ``obj`` as a single frame on this fist's PUB socket."""
        self.pub_f(self._make_frame(obj, msg_type, req_id, err_id))

    def push(self, fist_name, obj, msg_type, req_id, err_id=0):
        """Send ``obj`` as a single frame to ``fist_name``'s "pull" endpoint.

        Raises KeyError if reg_push(fist_name) has not succeeded.
        """
        ps_sock = self.push_socks[fist_name]
        ps_sock.send(self._make_frame(obj, msg_type, req_id, err_id).buf)

    def req(self, fist_name, obj, msg_type, req_id, err_id=0):
        """Send ``obj`` to ``fist_name``'s REP endpoint and return the reply.

        Raises:
            KeyError: if reg_req(fist_name) has not succeeded.
            Exception: if no reply arrives within REQ_WAIT_TIME ms.
        """
        req_sock = self.req_socks[fist_name]
        f = self._make_frame(obj, msg_type, req_id, err_id)
        try:
            req_sock.send(f.buf)
        except zmq.ZMQError as e:
            if e.errno != zmq.EFSM:
                raise
            # A REQ socket must strictly alternate send/recv; after an earlier
            # request timed out it is still waiting for a reply (EFSM).  Drop
            # the stale socket and retry once on a fresh one.
            req_sock.close(linger=0)
            req_sock = self._new_req_sock(self.fist_req_addrs[fist_name])
            self.req_socks[fist_name] = req_sock
            req_sock.send(f.buf)
        ret = self._recv_req(req_sock, Fist.REQ_WAIT_TIME)
        if ret is None:
            raise Exception("Timeout in req {}!".format(fist_name))
        return Frame(ret)

    def notify(self, title, content, notification_type):
        """Ask the master to send a notification; return whether it accepted."""
        req = message.ReqNotify()
        req.title = bytify(title)
        req.message = bytify(content)
        req.type = notification_type
        f = Frame()
        f.set_msg_type(MsgType.REQ_NOTIFY)
        f.set_source(self.get_source_id())
        f.set_nano(Timer.nano())
        f.set_data(req)
        ret = self.req_master(f)
        rsp_obj = ret.get_obj(message.RspNotify)
        if not rsp_obj.accepted:
            self.logger.error('[notify] failed, (err_msg){}'.format(rsp_obj.err_msg))
        return rsp_obj.accepted

    def send_req_command(self, target_fist_name, content):
        """Send a command request through the master; return its request id.

        The id returned by the master is recorded under ``cmd_lock`` so that
        the matching CMD_RESPONSE (handled in on_pub_master) is recognized.
        Returns -1 if the master reply is falsy.
        """
        req = message.ReqCommand()
        req.request_id = -1
        req.target_fist_type = FistType.NOT_AVAILABLE
        req.target_fist_name = target_fist_name
        req.from_fist_name = self.fist_name
        req.content = content
        f = Frame()
        f.set_msg_type(MsgType.CMD_REQUEST)
        f.set_source(self.get_source_id())
        f.set_nano(Timer.nano())
        f.set_string(json.dumps(req.to_dict()))
        # Holding the lock across the round trip keeps the response handler
        # from checking cmd_pending_rids before the new id is added.
        with self.cmd_lock:
            ret = self.req_master(f)
            if not ret:
                self.logger.error('[cmd] failed')
                return -1
            rid = ret.get_req_id()
            self.cmd_pending_rids.add(rid)
            return rid

    def send_rsp_command(self, request_id, content):
        """Send the response to command ``request_id`` through the master."""
        rsp = message.RspCommand()
        rsp.request_id = request_id
        rsp.fist_name = self.fist_name
        rsp.content = content
        f = Frame()
        f.set_msg_type(MsgType.CMD_RESPONSE)
        f.set_source(self.get_source_id())
        f.set_nano(Timer.nano())
        f.set_string(json.dumps(rsp.to_dict()))
        ret = self.req_master(f)
        return ret is not None

    def on_pub_frame(self, f):
        """Callback for each frame received from a subscribed fist."""
        pass

    def on_push_frame(self, f):
        """Callback for each frame received on the "pull" endpoint."""
        pass

    def on_req_frame(self, f):
        """Callback for a request on the REP socket; returns the reply frame.

        The default replies with an empty Frame().
        """
        return Frame()

    def on_req_command(self, request_id, from_fist_name, content):
        """Callback for a command addressed to this fist."""
        pass

    def on_rsp_command(self, request_id, from_fist_name, content):
        """Callback for the response to a command sent by this fist."""
        pass

    def on_pub_master(self, f):
        """Handle a frame from the master's PUB feed.

        CMD_SUICIDE addressed to this fist (or to any fist) stops it.
        CMD_REQUEST matching this fist's type and name goes to
        on_req_command.  CMD_RESPONSE for a pending request id goes to
        on_rsp_command.
        """
        msg_type = f.get_msg_type()
        if msg_type == MsgType.CMD_SUICIDE:
            suicide_req = f.get_obj(message.ReqFistSuicide)
            fist_name = suicide_req.fist_name.decode()
            if fist_name in [ANY_FIST_NAME, self.fist_name]:
                self.stop()
        elif msg_type == MsgType.CMD_REQUEST:
            s = f.get_string()
            req = message.ReqCommand(json.loads(s))
            if req.target_fist_type in [None, FistType.NOT_AVAILABLE, self.fist_type] and req.target_fist_name in [ANY_FIST_NAME, self.fist_name]:
                self.on_req_command(req.request_id, req.from_fist_name, req.content)
        elif msg_type == MsgType.CMD_RESPONSE:
            s = f.get_string()
            rsp = message.RspCommand(json.loads(s))
            with self.cmd_lock:
                pending = rsp.request_id in self.cmd_pending_rids
                self.cmd_pending_rids.discard(rsp.request_id)
            # Run the user callback outside the (non-reentrant) lock so that it
            # may itself call send_req_command() without deadlocking.
            if pending:
                self.on_rsp_command(rsp.request_id, rsp.fist_name, rsp.content)

    def on_close(self):
        """Callback invoked by stop()."""
        pass

    def start(self):
        """Install SIGTERM/SIGINT handlers and start the receive threads.

        Blocks until every receive thread has reported it is running.  Must be
        called from the main thread, because signal handlers can only be
        installed there.
        """
        signal.signal(signal.SIGTERM, self.signal_handler)
        signal.signal(signal.SIGINT, self.signal_handler)
        self.sub_ready = False
        self.pull_ready = False
        self.rep_ready = False
        self.sub_master_ready = False
        threads = [
            threading.Thread(target=target)
            for target in (self.run_sub, self.run_pull, self.run_rep, self.run_sub_master)
        ]
        for t in threads:
            t.daemon = True  # setDaemon() is deprecated
            t.start()
        while not (self.sub_ready and self.pull_ready and self.rep_ready and self.sub_master_ready):
            time.sleep(0.01)

    def stop(self, signum=signal.SIGTERM):
        """Mark the fist stopped (recording ``signum``) and call on_close()."""
        self.received_signal = signum
        self.on_close()

    def join(self):
        """Block the calling thread until the fist is stopped."""
        while True:
            time.sleep(0.01)
            if self.is_stopped():
                self.logger.debug('[main_thread] ended (sig){}'.format(self.received_signal))
                break

    def set_hb_status(self, status=HeartBeatStatus.NOT_AVAILABLE):
        """Set the status carried by subsequent heart-beat messages."""
        self.heart_beat_msg.hb_status = status

    def start_heart_beat(self, fist_to_push, sec_interval=5):
        """Push a heart beat to ``fist_to_push`` every ``sec_interval`` s.

        Runs in a daemon thread until the fist is stopped.
        """
        t = threading.Thread(target=self._keep_heart_beat, args=[fist_to_push, sec_interval])
        t.daemon = True
        t.start()

    def _keep_heart_beat(self, fist_to_push, sec_interval):
        self.logger.info('[hb] started! (f){} (sec){}'.format(fist_to_push, sec_interval))
        while not self.is_stopped():
            self.push(fist_to_push, self.heart_beat_msg, MsgType.FIST_HEART_BEAT, DEFAULT_REQ_ID)
            time.sleep(sec_interval)

    def signal_handler(self, signum=None, frame=None):
        """Stop on the first received signal; later signals are ignored."""
        if self.received_signal is None:
            self.stop(signum=signum)

    def get_source_id(self):
        """Return the source id assigned by the master (None before creation)."""
        return self.source_id

    ######################
    # internal functions #
    ######################
    def get_bind_addr(self, comm_type):
        """Ask the master for the address to bind ``comm_type`` at.

        Raises:
            Exception: if the master does not allow it.
        """
        req = message.ReqFistSet()
        req.comm_type = comm_type
        f = Frame()
        f.set_msg_type(MsgType.REQ_FIST_SET)
        f.set_source(self.get_source_id())
        f.set_nano(Timer.nano())
        f.set_data(req)
        ret = self.req_master(f)
        rsp_obj = ret.get_obj(message.RspFistSet)
        if rsp_obj.is_allowed:
            return rsp_obj.addr
        raise Exception("Rejected: (fist_name){} (comm_type){} (err_msg){}".format(self.fist_name, comm_type, rsp_obj.err_msg))

    def get_connect_addr(self, fist_name, comm_type):
        """Ask the master for ``fist_name``'s ``comm_type`` address.

        Returns None (after logging a warning) if the target is not registered
        or not connectable.
        """
        req = message.ReqFistReg()
        req.fist_name = bytify(fist_name)
        req.comm_type = comm_type
        f = Frame()
        f.set_msg_type(MsgType.REQ_FIST_REG)
        f.set_source(self.get_source_id())
        f.set_nano(Timer.nano())
        f.set_data(req)
        ret = self.req_master(f)
        rsp_obj = ret.get_obj(message.RspFistReg)
        # WARNING!! DO NOT use the addr if it's not connectable
        if not rsp_obj.input_registered or not rsp_obj.input_connectable:
            self.logger.warning('[get_connect_addr] rejected by master (fist_name){} (comm_type){} (reged){} (connectable){}'.format(fist_name, comm_type, rsp_obj.input_registered, rsp_obj.input_connectable))
            return None
        return rsp_obj.addr

    def _dispatch_frames(self, buf, handler):
        """Split a received message into frames and pass each to ``handler``.

        Frames are read back to back for as long as the current frame's
        status is HAS_NEXT (batched sends, see batch_pub_frame).
        """
        f = Frame(buf)
        handler(f)
        idx = f.get_length()
        while f.get_status() == FrameHeaderStatus.HAS_NEXT:
            f = Frame(buf[idx:])
            handler(f)
            idx += f.get_length()

    def _wait_readable(self, sock):
        """Wait up to POLL_WAIT_TIME ms; return True if ``sock`` has a message."""
        return bool(sock.poll(Fist.POLL_WAIT_TIME, zmq.POLLIN))

    def run_sub(self):
        """Receive loop for all reg_sub() sockets, dispatching to on_pub_frame."""
        self.sub_ready = True
        while not self.is_stopped():
            try:
                socks = dict(self.sub_poller.poll(Fist.POLL_WAIT_TIME))
                for sock in socks:
                    self._dispatch_frames(sock.recv(), self.on_pub_frame)
            except Exception:
                self.logger.exception('run_sub failed!')
                raise
        self.logger.debug('[sub_thread] ended (sig){}'.format(self.received_signal))

    def run_sub_master(self):
        """Receive loop for the master PUB feed, dispatching to on_pub_master."""
        self.sub_master_ready = True
        if not self.sub_master_sock:
            return
        while not self.is_stopped():
            try:
                # Poll with a timeout so the loop notices stop() even when idle.
                if not self._wait_readable(self.sub_master_sock):
                    continue
                self._dispatch_frames(self.sub_master_sock.recv(), self.on_pub_master)
            except Exception:
                self.logger.exception('run_sub_master failed!')
                raise
        self.logger.debug('[sub_master_thread] ended (sig){}'.format(self.received_signal))

    def run_pull(self):
        """Receive loop for the "pull" endpoint, dispatching to on_push_frame."""
        self.pull_ready = True
        if not self.pull_sock:
            return
        while not self.is_stopped():
            try:
                if not self._wait_readable(self.pull_sock):
                    continue
                self._dispatch_frames(self.pull_sock.recv(), self.on_push_frame)
            except Exception:
                self.logger.exception('run_pull failed!')
                raise
        self.logger.debug('[pull_thread] ended (sig){}'.format(self.received_signal))

    def run_rep(self):
        """Serve REP requests: reply with whatever on_req_frame returns."""
        self.rep_ready = True
        if not self.rep_sock:
            return
        while not self.is_stopped():
            try:
                if not self._wait_readable(self.rep_sock):
                    continue
                ret = self.rep_sock.recv()
                rsp_f = self.on_req_frame(Frame(ret))
                self.rep_sock.send(rsp_f.buf)
            except Exception:
                self.logger.exception('run_rep failed!')
                raise
        self.logger.debug('[rep_thread] ended (sig){}'.format(self.received_signal))

    # db setting
    def set_influxdb(self, host='localhost', port=8086, user='', passwd='', db='traders_link'):
        """Create the InfluxDB client used by write_points() and query()."""
        from influxdb import InfluxDBClient
        self.influx_db_client = InfluxDBClient(host, port, user, passwd, db)

    def write_points(self, points):
        """Write one point or a list of points; False if no client is set."""
        if self.influx_db_client is None:
            return False
        if not isinstance(points, list):
            points = [points]
        self.logger.debug('[write_points] (points){}'.format(points))
        return self.influx_db_client.write_points(points)

    def query(self, sql):
        """Run a query on the InfluxDB client; None if no client is set."""
        if self.influx_db_client is None:
            return None
        return self.influx_db_client.query(sql)
