import datetime
import decimal
import traceback

from easytradesdk import Serializer
from easytradesdk.Const import KlinePeriod
from easytradesdk.entity.Position import Position
from easytradesdk.entity.Slippage import Slippage

from easytradetesting.BackTestConfig import BackTestConfig
from easytradetesting.BackTestContext import BackTestContext
from easytradetesting.BackTestResult import BackTestResult
from easytradetesting.DataFetcher import DataFetcher
from easytradetesting.MySqlDataSource import MySqlDataSource
from easytradetesting.api.MarketApi import LocalMarketApi
from easytradetesting.api.TradeApi import TradeApi


class BackTestExecutor:
    """Drive a single back-test run: prepare data, step the strategy through time, build the result.

    Typical use is ``BackTestExecutor().execute(configDict)``; ``execute`` calls ``init`` itself.
    """

    def __init__(self):
        # All collaborators are created in init() from the back-test config.
        self.dataSource = None
        self.dataFetcher = None
        self.executeInstance = None
        self.backTestContext = None

    def init(self, backTestConfig):
        """Create the data source, fetcher, strategy instance and context from ``backTestConfig``.

        Also creates the back-test order table, fetches klines that are missing locally,
        seeds the positions and, when ``cleanBackTestOrders`` is set, deletes the orders
        previously stored for this back-test code.

        :param backTestConfig: a ``BackTestConfig`` object (not the raw dict).
        """
        _dsConfig = backTestConfig.dataSourceConfig
        _fetcherConfig = backTestConfig.dataFetcherConfig

        self.dataSource = MySqlDataSource(host=_dsConfig["host"], port=_dsConfig["port"], user=_dsConfig["user"], password=_dsConfig["password"], database=_dsConfig["database"])
        self.dataFetcher = DataFetcher(serverAddress=_fetcherConfig["serverAddress"], apiKey=_fetcherConfig["apiKey"], secret=_fetcherConfig["secret"], https=_fetcherConfig["https"])
        self.executeInstance = backTestConfig.executeClass()
        self.backTestContext = BackTestContext(backTestConfig.backTestCode)
        self.backTestContext.setMarketApi(LocalMarketApi(self.backTestContext))
        self.backTestContext.setTradeApi(TradeApi(self.backTestContext))
        self.backTestContext.setExecutePeriod(backTestConfig.executePeriod)
        self.backTestContext.dataSource = self.dataSource
        self.backTestContext.setStrategyParams(backTestConfig.strategyParams)
        self.dataSource.createBackTestOrderTable()

        _slippageConfigs = backTestConfig.slippageConfigs
        if _slippageConfigs:
            if "buy" in _slippageConfigs:
                _buy = _slippageConfigs["buy"]
                self.backTestContext.buySlippage = Slippage(_buy["slippageType"], _buy["value"])
            if "sell" in _slippageConfigs:
                _sell = _slippageConfigs["sell"]
                # The sell config belongs in sellSlippage; assigning buySlippage here
                # would silently overwrite the buy config and leave sell slippage unset.
                self.backTestContext.sellSlippage = Slippage(_sell["slippageType"], _sell["value"])

        self.__fetchKlines(backTestConfig)
        self.__initPosition(backTestConfig)

        if backTestConfig.cleanBackTestOrders:
            self.dataSource.deleteBackTestOrders(backTestConfig.backTestCode)

    def execute(self, backTestConfig: dict):
        """Run the back test described by ``backTestConfig`` and return a ``BackTestResult``.

        The strategy is stepped from ``startTimeMills`` in increments of the execute period.
        Each step runs stop-loss, stop-profit and the strategy itself; exceptions from those
        calls are printed and the run continues. Every step whose start value is non-zero
        records the total value of all positions at its start and end in the amount-change list.

        :param backTestConfig: raw config dict, converted with ``Serializer.dictToObject``.
        :raises Exception: if no 1-minute kline can be found to value a position at a step start.
        """
        _backTestConfig = Serializer.dictToObject(backTestConfig, BackTestConfig)

        self.init(_backTestConfig)

        _amountChangeList = []
        _startTimeStamp = _backTestConfig.startTimeMills
        _endTimeStamp = _backTestConfig.endTimeStampMills
        _executeInstance = self.executeInstance
        _backTestContext = self.backTestContext
        _periodMills = BackTestExecutor.resolveExecutePeriodMills(_backTestContext.getExecutePeriod())
        _executeInstance.init(_backTestContext)

        while True:

            # Valued outside the try blocks below: a missing start kline aborts the run.
            _startTotalAmount = self.__totalAmountAtStart(_startTimeStamp)

            try:
                _executeInstance.executeStopLoss(_backTestContext)
            except Exception:
                traceback.print_exc()
            try:
                _executeInstance.executeStopProfit(_backTestContext)
            except Exception:
                traceback.print_exc()
            try:
                _backTestContext.setExecutingTimeMills(_startTimeStamp)
                _executeInstance.execute(_backTestContext)

                # Last millisecond covered by this step.
                _endMills = _startTimeStamp + _periodMills - 1
                _endTotalAmount = self.__totalAmountAtEnd(_endMills)
                _rate = BackTestExecutor._calcProfitRate(_startTotalAmount, _endTotalAmount)

                # A zero start value has no defined profit rate; skip the entry instead of
                # raising a Decimal division error on every step.
                if _rate is not None:
                    _amountChangeList.append(
                        {
                            "timeMills": _startTimeStamp,
                            "startTotalAmount": _startTotalAmount, "endTotalAmount": _endTotalAmount, "profitRate": _rate
                        }
                    )

            except Exception:
                traceback.print_exc()

            # NOTE(review): the step whose start is >= endTimeStampMills is still executed before
            # the loop exits, although __fetchKlines only stores klines with timeMills < endTimeStampMills.
            if _startTimeStamp >= _endTimeStamp:
                break
            _startTimeStamp = _startTimeStamp + _periodMills

        try:
            _executeInstance.destroy(_backTestContext)
        except Exception:
            traceback.print_exc()

        # Build the back-test result.
        _backTestResult = BackTestResult(
            _backTestConfig.startDate, _backTestConfig.endDate, _backTestConfig.backTestCode, _backTestContext.backTestOrders, _backTestContext.getPositions(), _amountChangeList)

        _backTestResult.calculateProfits()
        _backTestResult.calculateMaxDrawDown()

        return _backTestResult

    def __totalAmountAtStart(self, timeMills):
        """Sum ``remainAmount + remainHolding * price`` over all positions at the start of a step.

        The price is the opening price of the first 1-minute kline queried from ``timeMills``;
        if none exists, the closing price of the latest 1-minute kline up to ``timeMills``.

        :raises Exception: if neither kline is found for some position.
        """
        _total = decimal.Decimal('0')
        for _position in self.backTestContext.getPositions().values():
            _tc = _position.tc
            _symbol = _position.symbol
            _klines = self.dataSource.queryKline(_tc, _symbol, KlinePeriod.MIN_1, startTimeMills=timeMills, limit=1)
            if _klines:
                _price = _klines[0].openingPrice
            else:
                _klines = self.dataSource.queryLatestKline(_tc, _symbol, KlinePeriod.MIN_1, endTimeMills=timeMills, limit=1)
                if not _klines:
                    raise Exception("kline to " + str(timeMills) + " not found")
                _price = _klines[0].closingPrice
            _total = _total + (_position.remainAmount + _position.remainHolding * _price)
        return _total

    def __totalAmountAtEnd(self, endMills):
        """Sum ``remainAmount + remainHolding * price`` over all positions at the end of a step.

        The price is the closing price of the latest 1-minute kline up to ``endMills``.

        :raises Exception: if no such kline is found for some position.
        """
        _total = decimal.Decimal('0')
        for _position in self.backTestContext.getPositions().values():
            _klines = self.dataSource.queryLatestKline(_position.tc, _position.symbol, KlinePeriod.MIN_1, endTimeMills=endMills, limit=1)
            if not _klines:
                raise Exception("kline to " + str(endMills) + " not found")
            _total = _total + (_position.remainAmount + _position.remainHolding * _klines[0].closingPrice)
        return _total

    @staticmethod
    def _calcProfitRate(startTotalAmount, endTotalAmount):
        """Return the percentage change from ``startTotalAmount`` to ``endTotalAmount``.

        Returns ``None`` when ``startTotalAmount`` is zero, because the rate is undefined.
        """
        if startTotalAmount == 0:
            return None
        return (endTotalAmount - startTotalAmount) / startTotalAmount * 100

    def __fetchKlines(self, backTestConfig):
        """Download configured klines missing locally for ``[startTimeMills, endTimeStampMills)``.

        For each entry of ``klineConfigs`` the kline table is created. If the remote and local
        counts differ, klines are fetched in batches of 200 and saved. Klines with
        ``timeMills >= endTimeStampMills`` are discarded.

        :returns: ``self``
        """
        _batchSize = 200

        for _klineConfig in backTestConfig.klineConfigs:

            _tc = _klineConfig["tc"]
            _symbol = _klineConfig["symbol"]
            _klinePeriod = _klineConfig["klinePeriod"]

            self.dataSource.createKlineTable(_tc, _symbol, _klinePeriod)

            print("Start to fetch klines of [" + _tc + ":" + _symbol + ":" + _klinePeriod + "] for " + backTestConfig.backTestCode + " ...")
            _remoteKlineCount = self.dataFetcher.countKline(_tc, _symbol, _klinePeriod, backTestConfig.startTimeMills, backTestConfig.endTimeStampMills)
            _localKlineCount = self.dataSource.countKline(_tc, _symbol, _klinePeriod, backTestConfig.startTimeMills, backTestConfig.endTimeStampMills)
            _startTimeMills = backTestConfig.startTimeMills

            if _remoteKlineCount == _localKlineCount:
                print("Klines of [" + _tc + ":" + _symbol + ":" + _klinePeriod + "] for " + backTestConfig.backTestCode + " already exists ...")
                continue

            while True:

                _klines = self.dataFetcher.fetchKline(_tc, _symbol, _klinePeriod, _startTimeMills, None, _batchSize)
                # tz-aware replacement for the deprecated utcfromtimestamp; same UTC text.
                _d = datetime.datetime.fromtimestamp(_startTimeMills / 1000, tz=datetime.timezone.utc)
                _t = _d.strftime("%Y-%m-%d %H:%M:%S")

                print("Fetched klines of [" + _tc + ":" + _symbol + ":" + _klinePeriod + "] from " + _t + " for " + backTestConfig.backTestCode + " ...")

                # Index of the first kline at/after the end of the back-test window, or -1.
                _idx = -1

                for i, v in enumerate(_klines):

                    if v.timeMills >= backTestConfig.endTimeStampMills:
                        _idx = i
                        break

                if _idx >= 0:
                    _klines = _klines[0:_idx]

                for _kline in _klines:
                    self.dataSource.saveKline(_tc, _symbol, _klinePeriod, _kline)

                # A short batch means the remote side is exhausted; _idx >= 0 means we reached the end.
                if len(_klines) < _batchSize or _idx >= 0:
                    break

                # Non-empty here: a full batch with no cut-off was just saved.
                _startTimeMills = _klines[-1].endTimeMills

        return self

    def __initPosition(self, backTestConfig):
        """Create positions from ``positionConfigs`` and set their initial and last tickers.

        Positions already present in the context are kept. Every position gets
        ``initialTicker`` from the opening price of the first 1-minute kline from
        ``startTimeMills`` and ``lastTicker`` from the closing price of the latest
        1-minute kline up to ``endTimeStampMills``.

        :raises Exception: if either kline is missing for some position.
        :returns: ``self``
        """
        _positions = self.backTestContext.getPositions()

        for _k, _v in backTestConfig.positionConfigs.items():

            if _k not in _positions:
                _position = Position()
                _position.tc = _v["tc"]
                _position.symbol = _v["symbol"]
                # Convert via str() so float config values do not carry binary rounding noise.
                _position.initialAmount = decimal.Decimal(str(_v["initialAmount"]))
                _position.initialHolding = decimal.Decimal(str(_v["initialHolding"]))
                _position.remainAmount = decimal.Decimal(str(_v["initialAmount"]))
                _position.remainHolding = decimal.Decimal(str(_v["initialHolding"]))
                _positions[_k] = _position

        for _position in _positions.values():

            _tc = _position.tc
            _symbol = _position.symbol

            _startKlines = self.dataSource.queryKline(_tc, _symbol, KlinePeriod.MIN_1, startTimeMills=backTestConfig.startTimeMills, limit=1)
            _endKlines = self.dataSource.queryLatestKline(_tc, _symbol, KlinePeriod.MIN_1, endTimeMills=backTestConfig.endTimeStampMills, limit=1)

            if not _startKlines:
                raise Exception("first kline from " + str(backTestConfig.startTimeMills) + " for " + backTestConfig.backTestCode + " not found")
            if not _endKlines:
                raise Exception("last kline to " + str(backTestConfig.endTimeStampMills) + " for " + backTestConfig.backTestCode + " not found")

            _position.initialTicker = _startKlines[0].openingPrice
            _position.initialTotalAmount = _position.initialAmount + _position.initialHolding * _position.initialTicker
            _position.lastTicker = _endKlines[0].closingPrice

        return self

    @staticmethod
    def resolveExecutePeriodMills(executePeriod):
        """Convert an execute period such as ``"5m"``, ``"1h"`` or ``"1d"`` into milliseconds.

        The last character is the unit (``m`` minutes, ``h`` hours, ``d`` days). The rest must
        be a positive integer count. A non-positive period would make ``execute`` loop forever.

        :raises Exception: ``"invalid executePeriod"`` for an empty value, unknown unit,
            non-integer count or non-positive count.
        """
        _unitMills = {'m': 60 * 1000, 'h': 3600 * 1000, 'd': 24 * 3600 * 1000}

        if not executePeriod or executePeriod[-1] not in _unitMills:
            raise Exception("invalid executePeriod")

        try:
            _cnt = int(executePeriod[:-1])
        except ValueError:
            raise Exception("invalid executePeriod") from None

        if _cnt <= 0:
            raise Exception("invalid executePeriod")

        return _cnt * _unitMills[executePeriod[-1]]
