#!/usr/bin/env python

import asyncio
import getopt
import os
import pprint
import sys
import traceback
import uuid
import toml
import importlib.util
from datetime import datetime, timedelta
from typing import List, Dict

import alpaca_trade_api as tradeapi
import pygit2
import pytz
from requests.exceptions import HTTPError

from liualgotrader.common import config, market_data, trading_data
from liualgotrader.common.database import create_db_connection
from liualgotrader.common.decorators import timeit
from liualgotrader.common.tlog import tlog
from liualgotrader.fincalcs.vwap import add_daily_vwap
from liualgotrader.models.algo_run import AlgoRun
from liualgotrader.models.new_trades import NewTrade
from liualgotrader.models.trending_tickers import TrendingTickers
from liualgotrader.strategies.momentum_long import MomentumLong
from liualgotrader.strategies.base import Strategy, StrategyType


def get_batch_list():
    """Pretty-print the trading batches returned by ``AlgoRun.get_batches()``.

    The async worker runs on a freshly created event loop, which is closed
    afterwards. KeyboardInterrupt and any other exception are logged (the
    latter with a traceback) instead of being propagated to the caller.
    """

    @timeit
    async def get_batch_list_worker():
        await create_db_connection()
        data = await AlgoRun.get_batches()
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(data)

    # Use a single loop for everything: install it as the current loop so
    # code calling asyncio.get_event_loop() inside the worker sees the same
    # loop that is running, and close it when done so it does not leak.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(get_batch_list_worker())
    except KeyboardInterrupt:
        tlog("get_batch_list() - Caught KeyboardInterrupt")
    except Exception as e:
        tlog(
            f"get_batch_list() - exception of type {type(e).__name__} with args {e.args}"
        )
        traceback.print_exc()
    finally:
        loop.close()


"""
Command-line interface: usage/version helpers and the back-testing driver.
"""


def show_usage():
    """Print the command-line usage of this script to stdout.

    Back-tests are started by passing one or more batch ids as positional
    arguments, so the synopsis lists them explicitly.
    """
    print(
        f"usage: {sys.argv[0]} [-v | --version] [--batch-list] "
        "[-d SYMBOL | --debug-symbol SYMBOL]... [batch-id ...]\n"
    )
    print("-v, --version\t\tDetailed version details")
    print(
        "--batch-list\tDisplay list of trading sessions, list limited to last 30 days"
    )
    print(
        "-d, --debug-symbol\tWrite verbose debug information for symbol SYMBOL during back-testing"
    )
    print("batch-id\tOne or more batch ids to back-test")


def show_version(filename: str, version: str) -> None:
    """Print the script file name followed by its git version label."""
    report = (f"filename:{filename}", f"git version:{version}")
    # trailing "\n" plus print's own newline leaves one blank line after
    print("\n".join(report) + "\n")


def backtest(batch_id: str, debug_symbols: List[str] = None, conf_dict: Dict = None) -> None:
    """Replay the configured strategies over a previously recorded batch.

    Loads the run details and trending tickers recorded for ``batch_id``,
    instantiates every strategy listed in ``conf_dict["strategies"]`` under a
    new batch id, replays one-minute bars for each ticker and saves the
    resulting trades via ``NewTrade.save``. The new batch id is printed when
    the run ends, whether or not it succeeded.

    :param batch_id: id of the recorded batch to replay.
    :param debug_symbols: symbols for which strategies run with ``debug=True``.
    :param conf_dict: parsed tradeplan; must contain a ``strategies`` list
        whose entries are single-key ``{name: details}`` tables.
    """
    data_api: tradeapi = tradeapi.REST(
        base_url=config.prod_base_url,
        key_id=config.prod_api_key_id,
        secret_key=config.prod_api_secret,
    )
    portfolio_value: float = 100000.0
    uid = str(uuid.uuid4())

    async def backtest_run(
        start: datetime, duration: timedelta, ref_run_id: int
    ) -> None:
        """Create the strategies and replay every trending ticker of the batch."""

        @timeit
        async def backtest_symbol(symbol: str) -> None:
            """Replay minute bars for ``symbol`` from ``start`` until market close."""
            est = pytz.timezone("America/New_York")
            start_time = pytz.utc.localize(start).astimezone(est)
            # align with the one-minute bar index
            start_time = start_time.replace(second=0, microsecond=0)
            debug = bool(debug_symbols) and symbol in debug_symbols
            print(
                f"--> back-testing {symbol} from {str(start_time)} duration {duration}"
            )
            if debug:
                print("--> using DEBUG mode")

            # load historical data
            try:
                symbol_data = data_api.polygon.historic_agg_v2(
                    symbol,
                    1,
                    "minute",
                    _from=str(start_time - timedelta(days=8)),
                    to=str(start_time + timedelta(days=1)),
                    limit=10000,
                ).df
            except HTTPError as e:
                tlog(f"Received HTTP error {e} for {symbol}")
                return

            if len(symbol_data) < 100:
                tlog(f"not enough data-points  for {symbol}")
                return

            add_daily_vwap(symbol_data, debug=debug)

            market_data.minute_history[symbol] = symbol_data
            print(
                f"loaded {len(market_data.minute_history[symbol].index)} agg data points"
            )

            position: int = 0
            minute_index = symbol_data["close"].index.get_loc(
                start_time, method="nearest"
            )
            new_now = symbol_data.index[minute_index]
            print(f"start time with data {new_now}")
            price = 0.0
            # NOTE: the replay runs until market close; `duration` is only reported.
            while (
                new_now < config.market_close
                and minute_index < symbol_data.index.size - 1
            ):
                if symbol_data.index[minute_index] != new_now:
                    print(
                        "mismatch!", symbol_data.index[minute_index], new_now
                    )
                    print(
                        symbol_data["close"][
                            minute_index - 10 : minute_index + 1
                        ]
                    )
                    raise Exception(f"bar index mismatch for {symbol} at {new_now}")

                price = symbol_data["close"][minute_index]
                for strategy in trading_data.strategies:
                    do, what = await strategy.run(
                        symbol,
                        position,
                        symbol_data[: minute_index + 1],
                        new_now,
                        portfolio_value,
                        debug=debug,
                        backtesting=True,
                    )
                    if not do:
                        continue

                    side = what["side"]
                    qty = int(float(what["qty"]))
                    if side == "buy":
                        position += qty
                        trading_data.last_used_strategy[symbol] = strategy
                        trading_data.buy_time[symbol] = new_now.replace(
                            second=0, microsecond=0
                        )
                    else:
                        position -= qty

                    db_trade = NewTrade(
                        algo_run_id=strategy.algo_run.run_id,
                        symbol=symbol,
                        qty=qty,
                        operation=side,
                        price=price,
                        indicators=trading_data.buy_indicators[symbol]
                        if side == "buy"
                        else trading_data.sell_indicators[symbol],
                    )

                    await db_trade.save(
                        config.db_conn_pool,
                        str(new_now),
                        trading_data.stop_prices[symbol],
                        trading_data.target_prices[symbol],
                    )
                    print(what)

                    # once a strategy traded this minute, skip the others
                    if side == "buy":
                        await strategy.buy_callback(symbol, price, qty)
                        break
                    elif side == "sell":
                        await strategy.sell_callback(symbol, price, qty)
                        break

                minute_index += 1
                new_now = symbol_data.index[minute_index]

            if position > 0:
                opening_strategy = trading_data.last_used_strategy[symbol]
                if opening_strategy.type == StrategyType.DAY_TRADE:
                    print(f"liquidate {position}")
                    db_trade = NewTrade(
                        # attribute the closing sell to the strategy that
                        # opened the position, not to whichever ran last
                        algo_run_id=opening_strategy.algo_run.run_id,
                        symbol=symbol,
                        qty=int(position),
                        operation="sell",
                        price=price,
                        indicators={"liquidate": 1},
                    )
                    await db_trade.save(config.db_conn_pool, str(symbol_data.index[minute_index-1]))

        symbols = await TrendingTickers.load(batch_id)
        print(f"loaded {len(symbols)}:\n {symbols}")

        if len(symbols) == 0:
            return

        est = pytz.timezone("America/New_York")
        start_time = pytz.utc.localize(start).astimezone(est)
        config.market_open = start_time.replace(
            hour=9, minute=30, second=0, microsecond=0
        )
        config.market_close = start_time.replace(
            hour=16, minute=0, second=0, microsecond=0
        )
        print(f"market_open{config.market_open}")
        strategy_types = []
        print("+=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=+")
        for strategy in conf_dict["strategies"]:
            strategy_name = list(strategy.keys())[0]
            strategy_details = strategy[strategy_name]
            if strategy_name == "MomentumLong":
                tlog(f"strategy {strategy_name} selected")
                strategy_types += [(MomentumLong, strategy_details)]
                continue

            tlog(f"custom strategy {strategy_name} selected")
            try:
                spec = importlib.util.spec_from_file_location(
                    "module.name", strategy_details["filename"]
                )
                custom_strategy_module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(  # type: ignore
                    custom_strategy_module
                )
                custom_strategy = getattr(custom_strategy_module, strategy_name)

                if not issubclass(custom_strategy, Strategy):
                    tlog(
                        f"custom strategy must inherit from class {Strategy.__name__}"
                    )
                    sys.exit(1)
                # build the constructor kwargs from a copy so the caller's
                # conf_dict still has "filename" for any later backtest() call
                strategy_kwargs = {
                    key: value
                    for key, value in strategy_details.items()
                    if key != "filename"
                }
                strategy_types += [(custom_strategy, strategy_kwargs)]

            except Exception as e:
                tlog(
                    f"[Error]exception of type {type(e).__name__} with args {e.args}"
                )
                sys.exit(1)

        config.env = "BACKTEST"
        # drop strategies created by an earlier backtest() call in this
        # process so they are not replayed against this batch as well
        # (assumes trading_data.strategies is a list — TODO confirm)
        trading_data.strategies.clear()
        for strategy_type, strategy_details in strategy_types:
            tlog(f"initializing {strategy_type.name}")
            s = strategy_type(batch_id=uid, ref_run_id=ref_run_id, **strategy_details)
            await s.create()
            trading_data.strategies.append(s)

        for symbol in symbols:
            await backtest_symbol(symbol)

    @timeit
    async def backtest_worker():
        """Load the recorded batch and start the replay."""
        await create_db_connection()
        run_details = await AlgoRun.get_batch_details(batch_id)
        # must be checked before unpacking: zip(*[]) yields nothing and the
        # three-way unpack would raise ValueError
        if not run_details:
            print(f"can't load data for batch id {batch_id}")
            return

        run_ids, starts, _ = zip(*run_details)
        # longest duration of the first schedule entry across all strategies
        longest_minutes = max(
            dict(next(iter(s.values())))["schedule"][0]["duration"]
            for s in conf_dict["strategies"]
        )
        await backtest_run(
            start=min(starts),
            duration=timedelta(minutes=longest_minutes),
            ref_run_id=run_ids[0],
        )

    # one loop, installed as current and closed afterwards
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(backtest_worker())
    except KeyboardInterrupt:
        tlog("backtest() - Caught KeyboardInterrupt")
    except Exception as e:
        tlog(
            f"backtest() - exception of type {type(e).__name__} with args {e.args}"
        )
        traceback.print_exc()
    finally:
        print("=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=")
        print(f"new batch-id: {uid}")
        loop.close()


if __name__ == "__main__":
    config.build_label = pygit2.Repository("./").describe(
        describe_strategy=pygit2.GIT_DESCRIBE_TAGS
    )
    config.filename = os.path.basename(__file__)
    if len(sys.argv) == 1:
        show_usage()
        sys.exit(0)

    # the tradeplan is only needed once there is something to do, so
    # printing usage does not require the configuration file to exist
    conf_dict = toml.load(config.configuration_filename)
    tlog(f"tradeplan file {config.configuration_filename} loaded")

    try:
        # "-b" is a plain flag like "--batch-list"; it takes no argument
        opts, args = getopt.getopt(
            sys.argv[1:], "vbd:", ["batch-list", "version", "debug-symbol="]
        )
    except getopt.GetoptError as e:
        print(f"Error parsing options:{e}\n")
        show_usage()
        sys.exit(2)

    # collect every debug symbol regardless of where it appears on the
    # command line relative to -v / --batch-list
    debug_symbols = [
        arg for opt, arg in opts if opt in ("--debug-symbol", "-d")
    ]
    for opt, _ in opts:
        if opt in ("-v", "--version"):
            show_version(config.filename, config.build_label)
            break
        elif opt in ("--batch-list", "-b"):
            get_batch_list()
            break

    for batch_id in args:
        backtest(batch_id, debug_symbols, conf_dict)

    sys.exit(0)
