"""Main module of the package"""
# Standard library imports
import logging
from typing import Optional, Any, Union
import datetime
import time

# Third party imports
import pandas as pd
from tqdm.auto import tqdm

# Local imports

# Global constants
LOGGER = logging.getLogger(__name__)


def run_backtest(
        df_positions_short : pd.DataFrame,
        df_execution_prices_full : pd.DataFrame,
        is_to_neutralize : bool=True,
        td_trading_delay : Optional[datetime.timedelta]=None,
        td_execution_duration : Optional[datetime.timedelta]=None,
        const_trading_fees_percent : float=0.01,
) -> pd.DataFrame:
    """Backtest wanted positions against prices and report PNL per tick

    Steps: optionally neutralize the positions, scale them, optionally
    shift them in time by the trading delay, sample the prices at the
    position ticks and finally derive PNL, cost and exposure columns.

    Args:
        df_positions_short (pd.DataFrame): Positions we want to take
        df_execution_prices_full (pd.DataFrame): Prices of assets
            in higher resolution
        is_to_neutralize (bool): If True the row mean is subtracted from
            the positions before scaling (long-short equal positions)
        td_trading_delay (datetime.timedelta): Offset added to the index
            of the positions; None means no shift
        td_execution_duration (datetime.timedelta): How long the execution
            should take; forwarded to calc_execution_price_rough
        const_trading_fees_percent (float): Broker commission fee in percent
            of the traded volume

    Returns:
        pd.DataFrame: One row per tick. Columns containing "PNL" come first
            (before_costs, execution_fee, const_trading_fee, after_costs,
            half_costs), followed by trading_volume, booksize, max weight,
            long/short count and long/short value.
    """
    LOGGER.debug("Run backtest")
    positions = df_positions_short
    if is_to_neutralize:
        LOGGER.debug("---> Neutralize positions for every datetime")
        positions = _neutralize(positions)
    LOGGER.debug("---> Scale positions to 1 for every datetime")
    positions = _scale(positions)
    # Account for the time needed to actually get into the positions
    if td_trading_delay is not None:
        LOGGER.debug("---> Add trading delay to the wanted positions")
        LOGGER.debug(
            "------> Last position time before delay: %s", positions.index[-1])
        positions = positions.shift(freq=td_trading_delay)
        LOGGER.debug(
            "------> Last position time after delay: %s", positions.index[-1])
    # Prices sampled exactly at the ticks where positions change
    LOGGER.debug("---> Convert execution prices to short format")
    ideal_prices = _convert_full_prices_df_to_short(
        positions, df_execution_prices_full)
    LOGGER.debug("------> Done")

    report = pd.DataFrame()
    # Holding PNL: generated by the positions taken at the previous tick
    LOGGER.debug("---> Calculate holding PNL")
    report["PNL before_costs"] = _calc_ser_holding_pnl(positions, ideal_prices)
    # How much every position has to change compared to the previous tick
    LOGGER.debug("---> Calculate the wanted position change on every tick")
    position_delta = positions.sub(positions.shift(1))
    # Prices we can realistically get while executing that change
    LOGGER.debug("---> Calculate execution prices")
    achievable_prices = calc_execution_price_rough(
        positions, df_execution_prices_full, td_execution_duration)
    LOGGER.debug("------> Done")
    # Cost of not being able to trade at the price sampled at the tick.
    # NOTE(review): the price gap here is in absolute price units while the
    # holding PNL uses relative price changes - confirm this is intended.
    slippage = position_delta.mul(achievable_prices.sub(ideal_prices))
    report["PNL execution_fee"] = slippage.sum(axis=1)
    # Volume traded at the current tick
    report["trading_volume"] = position_delta.abs().sum(axis=1)
    # Broker commission, charged as a percentage of the traded volume
    report["PNL const_trading_fee"] = (
        report["trading_volume"] * const_trading_fees_percent / 100.0)
    # Final PNL once both kinds of costs are subtracted
    report["PNL after_costs"] = (
        report["PNL before_costs"]
        - report["PNL execution_fee"]
        - report["PNL const_trading_fee"]
    )
    # Midpoint between the PNL before costs and the PNL after costs
    report["PNL half_costs"] = (
        report["PNL before_costs"] + report["PNL after_costs"]) / 2.0

    # Exposure statistics of the final positions
    absolute_positions = positions.abs()
    report["booksize"] = absolute_positions.sum(axis=1)
    report["max weight"] = absolute_positions.max(axis=1)

    # Entries outside of the leg become NaN and are skipped by count/sum
    long_leg = positions[positions > 0]
    short_leg = positions[positions < 0]
    report["long count"] = long_leg.count(axis=1)
    report["short count"] = short_leg.count(axis=1)
    report["long value"] = long_leg.sum(axis=1).abs()
    report["short value"] = short_leg.sum(axis=1).abs()

    return change_columns_order(report)


def change_columns_order(df_backtest_res):
    """Move every column with "PNL" in its name to the front

    The relative order of the columns is preserved inside both groups
    (the "PNL" columns and all the remaining ones).

    Args:
        df_backtest_res (pd.DataFrame): Frame whose columns are reordered

    Returns:
        pd.DataFrame: Same data, "PNL" columns first, other columns after
    """
    all_columns = list(df_backtest_res.columns)
    pnl_first = [name for name in all_columns if "PNL" in name]
    everything_else = [name for name in all_columns if "PNL" not in name]
    return df_backtest_res[pnl_first + everything_else]


def calc_execution_price_rough(
        df_positions_short : pd.DataFrame,
        df_execution_prices_full : pd.DataFrame,
        td_execution_duration : Optional[datetime.timedelta]=None
) -> pd.DataFrame:
    """Get execution price as mean price over execution duration

    Args:
        df_positions_short (pd.DataFrame): Positions we want to take
        df_execution_prices_full (pd.DataFrame): Prices of assets
            in higher resolution
        td_execution_duration (Optional[datetime.timedelta]): How long
            should the execution take. When it is None (or any other falsy
            value, e.g. a zero timedelta) the mean over a window of 2 rows
            is used instead of a time based window.

    Returns:
        pd.DataFrame: Averaged prices in the resolution of the positions
            (rows are selected by _convert_full_prices_df_to_short)
    """
    LOGGER.debug("------> Calculate full execution prices")

    # The averaging window has to look into the future, not the past:
    # reverse the rows, apply the rolling mean, then reverse back
    df_prices_reversed = df_execution_prices_full[::-1]
    if td_execution_duration:
        # Time based window; at least 2 observations are required,
        # otherwise the result for that row is NaN
        rolling_window = df_prices_reversed.rolling(
            td_execution_duration, min_periods=2)
    else:
        # No duration given: average the current and the next row
        rolling_window = df_prices_reversed.rolling(2)
    df_exec_price_full = rolling_window.mean()[::-1]

    LOGGER.debug("------> Convert them into short format")
    # Convert Execution prices to the short format
    df_exec_prices_short = _convert_full_prices_df_to_short(
        df_positions_short, df_exec_price_full)
    LOGGER.debug("---------> Done")
    return df_exec_prices_short


def _calc_ser_holding_pnl(
        df_positions_short : pd.DataFrame,
        df_execution_prices_short : pd.DataFrame
) -> pd.Series:
    """Calculate PNL generated by holding the positions of the previous tick

    Args:
        df_positions_short (pd.DataFrame): Positions at every tick
        df_execution_prices_short (pd.DataFrame): Prices at the same ticks

    Returns:
        pd.Series: For every tick the sum over all columns of
            (position at previous row) * (relative price change)
    """
    LOGGER.debug("------> Shift positions on 1 tick backward")
    # Positions which were held while the price was moving
    held_positions = df_positions_short.shift(1)
    # Relative price move from the previous row to the current one
    earlier_prices = df_execution_prices_short.shift(1)
    relative_move = df_execution_prices_short / earlier_prices - 1.0
    LOGGER.debug(
        "------> Multiply previous tick positions on price change percent")
    pnl_per_asset = held_positions * relative_move
    return pnl_per_asset.sum(axis=1)


def _convert_full_prices_df_to_short(
        df_positions_short : pd.DataFrame,
        df_execution_prices_full : pd.DataFrame,
) -> pd.DataFrame:
    """Convert prices in high resolution to the resolution of short DFs"""
    set_index_short = set(df_positions_short.index.tolist())
    set_index_full = set(df_execution_prices_full.index.tolist())
    set_missing_ticks_in_full = set_index_short - set_index_full
    set_common_ticks = set_index_short.intersection(set_index_full)

    if set_missing_ticks_in_full:
        LOGGER.warning("Missing ticks: %d", len(set_missing_ticks_in_full))

    if len(set_missing_ticks_in_full) > 5:
        list_missing_ticks = sorted(set_missing_ticks_in_full)
        LOGGER.warning("Missing ticks: %s", list_missing_ticks[:20])
        raise ValueError(
            f"Unable to convert full df into short format as "
            f"{len(set_missing_ticks_in_full)}/{len(set_index_short)} "
            "ticks missing"
        )
    return df_execution_prices_full.loc[sorted(set_common_ticks)]


def _neutralize(df_positions_short : pd.DataFrame) -> pd.DataFrame:
    """Neutralize positions to be long-short equal"""
    return df_positions_short.sub(df_positions_short.mean(axis=1), axis=0)


def _scale(df_positions_short : pd.DataFrame) -> pd.DataFrame:
    """
    Scale to have sum of absolute positions at every tick equals to 1.0
    """
    return df_positions_short.div(df_positions_short.abs().sum(axis=1), axis=0)
