# auto generated by update_py.py
import logging
import os
import queue
import time
from concurrent.futures import ThreadPoolExecutor

from google.protobuf.json_format import MessageToDict, ParseDict
from peewee import fn, JOIN

from ..database.sqlite_models import rtn_order_db, rtn_trade_db, position_db, capital_db
from ..database.sqlite_models import RtnOrder, RtnTrade, Position, Capital
from ..pb_msg import message_pb


class BaseHelper:
    """Generic bridge between one protobuf message type and one database model.

    Messages are flattened to dicts for storage and rebuilt from row dicts
    when queried.
    """

    def __init__(self, db_path, db_name, model, message):
        """Remember the collaborators and prepare the database.

        Args:
            db_path: Location handed to the database's ``init`` call.
            db_name: Despite the name, this is the database *object*; it is
                initialised with ``db_path`` and used as a context manager.
            model: Model class whose table stores the messages.
            message: Zero-argument callable producing an empty message that
                query results are parsed into.
        """
        self.message = message
        self.model = model
        self.database = db_name

        self.init_db(db_path)

    def init_db(self, db_path):
        """Point the database at ``db_path`` and create the model's table."""
        self.database.init(db_path)
        with self.database as connection:
            connection.create_tables([self.model])

    def save_obj(self, obj):
        """Insert ``obj`` as a new row.

        Field names are kept as declared in the proto file and enum values
        are written as integers.
        """
        row = MessageToDict(
            obj,
            use_integers_for_enums=True,
            preserving_proto_field_name=True,
        )
        insertion = self.model.insert(row)
        insertion.execute()

    def get_objs(self, **kwargs):
        """Return stored rows as messages, filtered by equality on ``kwargs``.

        A keyword is skipped when its value is ``None`` or when the model has
        no attribute of that name. Keys in a row that the message type does
        not know are ignored while parsing.
        """
        query = self.model.select()
        for field_name, wanted in kwargs.items():
            if wanted is not None and hasattr(self.model, field_name):
                query = query.where(getattr(self.model, field_name) == wanted)

        rows = query.dicts()
        return [ParseDict(row, self.message(), ignore_unknown_fields=True) for row in rows]


class RtnOrderHelper(BaseHelper):
    """Helper for ``RtnOrder`` messages backed by the ``RtnOrder`` model."""

    def __init__(self, db_path):
        """Bind the helper to the order database located at ``db_path``."""
        super().__init__(db_path, rtn_order_db, RtnOrder, message_pb.RtnOrder)

    def save_obj(self, obj: message_pb.RtnOrder):
        """Store ``obj`` through the model's ``replace`` query instead of ``insert``.

        Presumably this lets a later update of the same order overwrite the
        earlier row; that depends on the model's key definition, which is
        not visible here.
        """
        row = MessageToDict(
            obj,
            use_integers_for_enums=True,
            preserving_proto_field_name=True,
        )
        upsert = self.model.replace(row)
        upsert.execute()


class RtnTradeHelper(BaseHelper):
    """Helper for ``RtnTrade`` messages; uses the base save/query behaviour as is."""

    def __init__(self, db_path):
        """Bind the helper to the trade database located at ``db_path``."""
        super().__init__(db_path, rtn_trade_db, RtnTrade, message_pb.RtnTrade)


class PositionHelper(BaseHelper):
    """Helper for ``GatewayPosition`` snapshots.

    A snapshot is stored as one row per entry of its ``positions`` list, each
    row tagged with the snapshot's ``pos_id``.
    """

    def __init__(self, db_path):
        """Bind the helper to the position database located at ``db_path``."""
        BaseHelper.__init__(self, db_path, position_db, Position, message_pb.GatewayPosition)

    def save_obj(self, obj: message_pb.GatewayPosition):
        """Insert every position of ``obj`` in a single transaction.

        Only the entries of ``positions`` (plus the snapshot's ``pos_id``) are
        written; other top-level fields of the message are not persisted.

        Raises:
            KeyError: If the converted message has no ``positions`` or no
                ``pos_id`` key.
        """
        obj_dict = MessageToDict(obj, preserving_proto_field_name=True, use_integers_for_enums=True)
        for o in obj_dict['positions']:
            o['pos_id'] = obj_dict['pos_id']

        with self.database.atomic():
            self.model.insert_many(obj_dict['positions']).execute()

    def get_objs(self, **kwargs):
        """Return the most recent snapshot rows, regrouped into messages.

        For every ``(account_id, sub_account)`` pair only the rows carrying
        that pair's highest ``pos_id`` are selected. Keyword filters behave
        like ``BaseHelper.get_objs``: ``None`` values and names the model does
        not have are skipped, everything else is an equality filter.

        Rows sharing a ``pos_id`` are folded into one message of the form
        ``{'pos_id': ..., 'positions': [row, ...]}``, mirroring what
        ``save_obj`` flattens. Row keys unknown to the message types are
        ignored during parsing.

        NOTE(review): the layout of ``GatewayPosition`` is inferred from
        ``save_obj``; confirm against the proto definition.

        Returns:
            A list of messages, empty when nothing matches.
        """
        # Highest pos_id per (account_id, sub_account).
        latest = self.model.select(
            fn.MAX(self.model.pos_id).alias('pos_id'),
            self.model.account_id,
            self.model.sub_account,
        ).group_by(self.model.account_id, self.model.sub_account)
        predicate = (
            (self.model.pos_id == latest.c.pos_id)
            & (self.model.account_id == latest.c.account_id)
            & (self.model.sub_account == latest.c.sub_account)
        )
        rows = self.model.select().join(latest, JOIN.INNER, on=predicate)

        for k, v in kwargs.items():
            if v is None:
                continue
            if hasattr(self.model, k):
                rows = rows.where(getattr(self.model, k) == v)

        # Undo the flattening done by save_obj: one snapshot per pos_id.
        snapshots = {}
        for row in rows.dicts():
            pos_id = row['pos_id']
            snapshot = snapshots.setdefault(pos_id, {'pos_id': pos_id, 'positions': []})
            snapshot['positions'].append(row)

        return [ParseDict(s, self.message(), True) for s in snapshots.values()]


class CapitalHelper(BaseHelper):
    """Helper for ``RspAccount`` messages backed by the ``Capital`` model."""

    def __init__(self, db_path):
        """Bind the helper to the capital database located at ``db_path``."""
        super().__init__(db_path, capital_db, Capital, message_pb.RspAccount)

    def get_objs(self, **kwargs):
        """Return one message per ``(account_id, sub_account)`` group.

        The query selects ``MAX(req_id)`` together with the model's columns
        and groups by account and sub-account. Presumably the intent is to
        get the row with the highest ``req_id`` of each group; which row the
        non-aggregated columns come from is decided by the database engine,
        so confirm for the backend in use.

        Keyword filters follow the base class rules: ``None`` values and
        names the model lacks are skipped, the rest are equality filters.
        """
        model = self.model
        query = model.select(fn.MAX(model.req_id), model).group_by(
            model.account_id,
            model.sub_account,
        )
        for field_name, wanted in kwargs.items():
            if wanted is not None and hasattr(model, field_name):
                query = query.where(getattr(model, field_name) == wanted)

        rows = query.dicts()
        return [ParseDict(row, self.message(), ignore_unknown_fields=True) for row in rows]


class DBHelper:
    """Facade that persists gateway messages asynchronously and reads them back.

    Each message type has its own helper, its own queue and one background
    worker that moves objects from the queue into the database.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, db_path):
        """Create the storage directory, the helpers and the worker threads.

        Args:
            db_path: Directory that holds the individual ``*.db`` files; it is
                created when missing.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(db_path, exist_ok=True)

        self._to_stop = False

        self._rtn_order_helper = RtnOrderHelper(os.path.join(db_path, 'rtn_order.db'))
        self._rtn_trade_helper = RtnTradeHelper(os.path.join(db_path, 'rtn_trade.db'))
        self._position_helper = PositionHelper(os.path.join(db_path, 'position.db'))
        self._capital_helper = CapitalHelper(os.path.join(db_path, 'capital.db'))

        self._order_info_queue = queue.Queue()
        self._rtn_order_queue = queue.Queue()
        self._rtn_trade_queue = queue.Queue()
        self._position_queue = queue.Queue()
        self._capital_queue = queue.Queue()

        self.thread_pool = ThreadPoolExecutor(10)
        self.thread_pool.submit(self._proc_queue, self._rtn_order_queue, self._rtn_order_helper)
        self.thread_pool.submit(self._proc_queue, self._rtn_trade_queue, self._rtn_trade_helper)
        self.thread_pool.submit(self._proc_queue, self._position_queue, self._position_helper)
        self.thread_pool.submit(self._proc_queue, self._capital_queue, self._capital_helper)

    def _proc_queue(self, proc_queue, proc_helper):
        """Worker loop: hand every queued object to ``proc_helper.save_obj``.

        The loop keeps running until a stop has been requested *and* the
        queue is empty, so objects queued before ``close()`` are still
        written. A failing save is logged and skipped; without that the
        exception would end the worker and every later object of this queue
        would silently never be stored.
        """
        while not self._to_stop or not proc_queue.empty():
            try:
                # The timeout lets the loop re-check the stop flag regularly.
                obj = proc_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            try:
                proc_helper.save_obj(obj)
            except Exception:
                self._logger.exception(
                    'failed to save %s via %s',
                    type(obj).__name__,
                    type(proc_helper).__name__,
                )

    def save(self, obj):
        """Queue ``obj`` for persistence according to its message type.

        Messages whose identifying fields are empty are dropped: ``order_id``
        for orders, ``order_id`` and ``trade_ref`` for trades, ``pos_id`` and
        ``positions`` for positions, ``req_id`` for account responses.
        Objects of any other type are ignored.
        """
        if isinstance(obj, message_pb.RtnOrder):
            if obj.order_id:
                self._rtn_order_queue.put(obj)
        elif isinstance(obj, message_pb.RtnTrade):
            if obj.order_id and obj.trade_ref:
                self._rtn_trade_queue.put(obj)
        elif isinstance(obj, message_pb.GatewayPosition):
            if obj.pos_id and obj.positions:
                self._position_queue.put(obj)
        elif isinstance(obj, message_pb.RspAccount):
            if obj.req_id:
                self._capital_queue.put(obj)

    ###################
    # get obj from db #
    ###################

    def get_order_info(self, order_id):
        """Return the first stored order matching ``order_id``, or ``None``."""
        obj = self._rtn_order_helper.get_objs(order_id=order_id)
        return obj[0] if obj else None

    def get_history_rtn_order(self, account_id=None, sub_account=None):
        """Return stored orders, optionally filtered by account / sub-account."""
        return self._rtn_order_helper.get_objs(account_id=account_id, sub_account=sub_account)

    def get_history_rtn_trade(self, account_id=None, sub_account=None):
        """Return stored trades, optionally filtered by account / sub-account."""
        return self._rtn_trade_helper.get_objs(account_id=account_id, sub_account=sub_account)

    def get_history_position(self, account_id=None, sub_account=None):
        """Return the position helper's query result for the given filters."""
        return self._position_helper.get_objs(account_id=account_id, sub_account=sub_account)

    def get_history_capital(self, account_id=None, sub_account=None):
        """Return the capital helper's query result for the given filters."""
        return self._capital_helper.get_objs(account_id=account_id, sub_account=sub_account)

    def close(self):
        """Stop the workers after they have drained their queues.

        Blocks until every worker has finished, so everything queued before
        this call has been handed to its helper when the method returns.
        Calling it more than once is harmless.
        """
        self._to_stop = True
        self.thread_pool.shutdown(wait=True)
