import json
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from decimal import Decimal
from pathlib import Path
from typing import Union

import pandas as pd
from tqdm import tqdm

from powerbot_backtesting.models.exporter_models import ApiExporter
from powerbot_backtesting.utils import _order_matching, _battery_order_matching_free, _battery_order_matching_linked


class BacktestingBase(ABC):
    """
    Abstract foundation shared by all backtesting algorithm classes.

    Concrete subclasses have to implement match_orders and algorithm; adjust_params is an optional hook that does nothing by default.
    """

    def _generate_trading_report(self) -> bool:
        """
        Dumps self.results as JSON into the analysis output cache directory (created on demand).

        If the instance has no filename attribute yet, a default name is assigned. The name of the written file is derived
        from self.filename by replacing "input" with "output" and "csv" with "json".

        Returns:
            bool: always True
        """
        output_dir = Path("./__pb_cache__/analysis_output")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Fall back to the default report name when no file name has been set
        self.filename = getattr(self, "filename", "battery_backtesting_output.json")

        report_name = self.filename.replace("input", "output").replace("csv", "json")
        with open(output_dir.joinpath(report_name), "w") as report_file:
            json.dump(self.results, report_file)

        return True

    @abstractmethod
    def match_orders(self, **kwargs) -> Decimal:
        """
        Order matching function of the algorithm class; has to be implemented by subclasses.
        """

    def adjust_params(self):
        """
        Optional hook for calculation logic on contract level; the default implementation does nothing.

        Parameters changed here are usable for each iteration of the order books. The hook should only be used to adjust the
        member variable self.params. Please be aware, that deletion of the key "position" will lead to an error.
        """
        return None

    @abstractmethod
    def algorithm(self, **kwargs):
        """
        Strategy logic of the algorithm class; has to be implemented by subclasses.
        """


class BacktestingAlgo(BacktestingBase):
    """
    Class aiming to provide a backbone to backtesting strategies.

    This class is used by providing two parameters: the order books and an input file based on these order books.
    The input file can be generated with the command "generate_input_file". The output is a csv file that needs to be
    filled with positions and other signals you wish to utilize in your backtesting strategy.

    Any keyworded arguments you pass will be automatically added as class attributes. This allows you to only overwrite
    a single function and still be able to use all the parameters you need for your strategy.

    To write your own backtesting algorithm, create a new class that inherits from BacktestingAlgo and overwrite
    the algorithm function. All signals/ columns that you have added to the input file will be reflected as dictionary
    keys in self.params. Please be aware, that the column names will be the key names after conversion.

    The most important feature of this class is its order matching logic. When writing your own algorithm, this
    function needs to be called whenever your calculations are done and you want to see if you would make any trades.

    Any trades you perform alter the order book indirectly and will also be gathered in a list of results.
    """

    def __init__(self,
                 orderbooks: dict[str, pd.DataFrame],
                 file_input: Union[str, Path],
                 trades: Union[dict[str, pd.DataFrame], None] = None,
                 generate_report: bool = True,
                 **kwargs):
        """
        Args:
            orderbooks (dict): Order books as generated by get_orderbooks
            file_input (str/ Path): Path to input file (semicolon separated csv with delivery_start and delivery_end columns)
            trades (dict): OPTIONAL Trade data for the same contracts as in the order books (acquired via get_public_trades)
            generate_report (bool): True if trade report should be created (highly recommended)
            **kwargs: all additional parameters that are needed for the execution of the algorithm

        Raises:
            FileNotFoundError: if the input file is found neither in the input cache directory nor at the given path
        """
        # NOTE(review): run() calls reset_index()/groupby() on this object, i.e. it is used like a single DataFrame
        # with delivery_start, delivery_end and timestep available -- confirm against get_orderbooks.
        self.orderbooks = orderbooks
        # Keep only the file name: the report is written below the output cache directory and a (possibly absolute)
        # directory part would redirect the report, in the worst case onto the input file itself.
        self.filename = Path(file_input).name
        self.all_trades = trades
        self.trade_list = []
        self.results = []
        self.generate_report = generate_report

        # Load file input: first try the cache directory, then interpret file_input as a path of its own
        cache_path = Path("./__pb_cache__/analysis_input")

        try:
            self.algo_params = pd.read_csv(cache_path.joinpath(file_input), sep=";")
        except FileNotFoundError:
            self.algo_params = pd.read_csv(file_input, sep=";")

        self.algo_params["delivery_start"] = pd.to_datetime(self.algo_params["delivery_start"], utc=True)
        self.algo_params["delivery_end"] = pd.to_datetime(self.algo_params["delivery_end"], utc=True)
        self.algo_params = self.algo_params.set_index(["delivery_start", "delivery_end"])

        # Convert all keyworded args to attributes
        for name, value in kwargs.items():
            setattr(self, name, value)

    def filter_trades(self, key: str, timestamp: str) -> Union[None, pd.DataFrame]:
        """
        Returns all trades of a contract that were executed up to (and including) the given timestamp.

        Args:
            key: Key of the contract in the trades collection passed to the constructor
            timestamp: Upper bound for the exec_time column

        Returns:
            pd.DataFrame or None: None if no trades were provided or no trade matches
        """
        if not self.all_trades:
            return None

        contract_trades = self.all_trades[key]
        past_trades = contract_trades.loc[contract_trades.exec_time <= timestamp]

        return None if past_trades.empty else past_trades

    def match_orders(self,
                     side: str,
                     orderbook: pd.DataFrame,
                     timestamp: str,
                     price: Union[int, float],
                     position: Union[int, float],
                     contract_time: int,
                     key: str = None,
                     vwap: float = None,
                     order_execution: str = "NON") -> Decimal:
        """
        Wrapper for _order_matching function.

        Matches orders according to input parameters; adds trades made to trade_list and returns the remaining quantity.

        The order_execution parameter can be added to decide according to which logic the quantity should be filled. Allowed
        values are:

        NON - No restriction, partial & full execution are allowed

        FOK - Fill or Kill, if order isn't filled completely by first matching order, next matching order is loaded ->
        if none match next order book is loaded

        IOC - Immediate and Cancel, order is executed to maximum extent by first matching order, next order book is loaded ->
        allows price adjustments

        Args:
            side (str): buy/sell
            orderbook (DataFrame): Single order book
            timestamp (str): Timestamp of order book
            price (int): Minimum/ Maximum Price for Transaction
            position (int, float): Quantity that can/ should be traded
            contract_time (int): contract time in minutes, either 60, 30 or 15
            key (str): Position in trade list (accepted for compatibility, not forwarded to _order_matching)
            vwap (float): optional value to display current VWAP in the list of executed trades
            order_execution (str): Type of order execution that should be simulated

        Returns:
            Decimal: remaining quantity
        """
        # self.exec_orders_list is created per contract inside run()
        return _order_matching(side, orderbook, timestamp, price, position, self.exec_orders_list, self.trade_list, contract_time, vwap,
                               order_execution)

    def run(self) -> Union[pd.DataFrame, None]:
        """
        Main function to execute custom trading logic. Handles all necessary steps and then calls the custom algorithm function.

        For every contract (grouped by delivery_start/delivery_end) the parameters of the input file are loaded into self.params,
        adjust_params is called once and afterwards algorithm is called for every order book in ascending timestep order until
        it returns a falsy value.

        Returns:
            pd.DataFrame or None: all executed trades indexed by delivery_start/delivery_end, or None if no trade was made. If
            generate_report is set, the trades are additionally written as csv into the analysis output cache directory.
        """
        for key, contract in self.orderbooks.reset_index().groupby(["delivery_start", "delivery_end"]):
            if contract.empty:
                # Nothing to trade for this contract
                continue

            # Executed orders are tracked per contract
            self.exec_orders_list = {}
            self.params = self.algo_params.loc[key].to_dict()

            # Contract length in minutes
            self.params["contract_time"] = (key[1] - key[0]).seconds / 60

            # Overrideable function to set custom calculation logic on contract level
            self.adjust_params()

            contract = contract.sort_values("timestep").set_index("timestep")

            # Iteration through every timestep/orderbook
            for timestamp, orderbook in contract.groupby(contract.index):
                # Add trades to params for this run
                self.params["trades"] = self.filter_trades(key, timestamp)

                print(timestamp)

                # Overrideable function to set behaviour on order book level; a falsy return value ends the contract
                if not self.algorithm(timestamp, orderbook, key):
                    break

        if not self.trade_list:
            print("No Trades were generated based on the given input signals")
            return None

        trades_df = pd.concat([pd.DataFrame([trade]) for trade in self.trade_list]).reset_index(drop=True)
        # set_index is not an in-place operation, its result has to be kept
        self.trades_df = trades_df.set_index(["delivery_start", "delivery_end"])

        if self.generate_report:
            cache_path = Path("./__pb_cache__/analysis_output")
            cache_path.mkdir(parents=True, exist_ok=True)

            # File creation
            self.filename = getattr(self, "filename", "battery_backtesting_output.json")

            self.trades_df.to_csv(cache_path.joinpath(self.filename), sep=";")

        return self.trades_df

    @staticmethod
    def calc_vwap(trades: pd.DataFrame,
                  timestamp: str,
                  time_spec: str = "60T-60T-0T") -> float:
        """
        Function to calculate the volume-weighted average price at the given point in time for the last X hours.

        Thin wrapper around ApiExporter.vwap_by_timeperiod.

        To specify the time period precisely, the time_spec parameter should be used. The pattern is always as follows:

        {60/30/15/0}T-{60/45/30/15}T-{45/30/15/0}T

        Explanation:
            {60/30/15/0}T -> Floor, will count back to the last full hour/ half-hour/ quarter-hour / last minute and act as starting point
            {60/45/30/15}T -> Execution From, determines the minutes that should be subtracted from Floor to reach starting point for calculation
            {45/30/15/0}T -> Execution To, determines the minutes that should be subtracted from Floor to reach end point for calculation

        Examples:
            60T-60T-0T <--> VWAP of the previous trading hour.
            60T-15T-0T <--> VWAP of the last quarter hour of the previous trading hour.
            60T-30T-15T <--> VWAP of third quarter-hour of the previous trading hour.
            15T-60T-0T <--> VWAP of last hour calculated from last quarter hour.
            0T-60T-30T <--> VWAP of first half of the last hour calculated from current timestamp.

        Args:
            trades (pd.DataFrame): Collection of trades
            timestamp (str): Current timestamp
            time_spec (str): String of time specification as explained above

        Returns:
            float
        """
        return ApiExporter.vwap_by_timeperiod(trades, timestamp, time_spec)

    @abstractmethod
    def algorithm(self,
                  timestamp: str,
                  orderbook: pd.DataFrame,
                  key: str) -> bool:
        """
        Overridable function that defines behaviour and calculations for a specific strategy. Please make sure to define the same input parameters
        as the original function.

        Since the order matching function already saves every important action it takes, it does not need any manual adjustments.

        The only requirement for this function is that at least the position has to be recalculated according to the return value of the order
        matching function. Every other parameter (e.g. price) can be changed according to the specific trading strategy.

        An example for a simple position closer can be found in the PowerBot Knowledge Base.

        Args:
            timestamp (str): Timestamp of order book
            orderbook (pd.DataFrame): Order book
            key (str): Contract timeframe; run() passes the (delivery_start, delivery_end) group key

        Returns:
            bool: a falsy value stops the processing of the remaining order books of the current contract
        """
        pass


class BatteryBacktestingAlgo(BacktestingBase):
    """
    Class aiming to provide a backbone to battery backtesting strategies.

    A battery needs to acquire energy in an earlier delivery period to be able to sell it again at a later point in time. This core principle is
    implemented with the use of the PositionStorage class that keeps track of all current positions and ensures that the capacity is never exceeded.

    The algorithm can be parametrized according to a user's strategy. In particular, the match_hours parameter can be used to avoid situations, where
    a battery is filled in the morning and the energy is sold in the evening. This would miss a lot of potential profit, since trading more often
    almost always leads to more profit overall.
    Also the max_perc_per_trade parameter can be used to further control how orders are matched. This parameter sets a limit on how much of the
    maximum capacity can be used in a single transaction (buy, sell, linked buy & sell).

    To write your own battery backtesting algorithm, create a new class that inherits from BatteryBacktestingAlgo and overwrite the algorithm
    function. Any trades you perform alter the order book indirectly and will also be gathered in a list of results.

    There are 2 general modes this algorithm can operate in:
        linked:
            A position can only be opened if there is another order that is able to close the full amount opened. This approach uses the flexibility
            a battery provides but does not speculate on prices. It will not fill the battery without a guaranteed sell.

        free:
            A position can be opened regardless of any other open order. Your maximum capacity cannot be overshot, however.
    """

    def __init__(self,
                 orderbooks: dict[str, pd.DataFrame],
                 battery_capacity: Union[int, float],
                 match_hours: int = 5,
                 max_perc_per_trade: float = 1.0,
                 behaviour: str = "linked",
                 generate_report: bool = True,
                 **kwargs):
        """
        Args:
            orderbooks (dict): Order books as generated by get_orderbooks
            battery_capacity (int): The maximum amount of MW the battery storage can hold
            match_hours (int): The maximum amount of hours a contract can be in the future from the current point in time to be still matchable
            max_perc_per_trade (float): Maximum percentage of available capacity that can be used for a single order
            behaviour (str): Determines if positions can only be opened if there exists a corresponding order in a later contract (linked) or
                positions can be opened as much as battery capacity allows without needing an already existing opposite order (free).
            generate_report (bool): True if trade report should be created (highly recommended)
            **kwargs: all additional parameters that are needed for the execution of the algorithm

        Raises:
            ValueError: if behaviour is neither 'linked' nor 'free'
        """
        if behaviour not in ["linked", "free"]:
            raise ValueError("Algorithm behaviour has to be either 'linked' or 'free'")
        self.orderbooks = orderbooks
        self.battery_capacity = battery_capacity
        self.positions = PositionStorage(battery_capacity)
        self.current_capacity = 0
        self.match_hours = match_hours
        self.max_perc_per_trade = max_perc_per_trade
        self.generate_report = generate_report
        self.behaviour = behaviour
        self.trade_list = {}
        self.results = {}
        # Collection for already executed orders
        self.exec_orders_list = {}

        # Convert all keyworded args to attributes
        for name, value in kwargs.items():
            setattr(self, name, value)

    def _aggregate_orderbooks(self):
        """
        Function to aggregate multiple order books into a single collection to check against an open order.

        Every order book gets a new first column "del_period" holding its delivery period key (the frames in self.orderbooks are
        modified in place). Afterwards all order books sharing a timestamp are stacked into one dataframe. Per timestamp, delivery
        periods lying more than match_hours after the first delivery period seen for that timestamp are skipped.

        The delivery period keys are parsed as "YYYY-MM-DD HH:MM - HH:MM" and compared as strings.

        Returns:
            dict[str, pd.DataFrame]: aggregated order books by timestamp

        Raises:
            ValueError: if a single order book contains more than one contract_id
        """
        # Create defaultdict to store aggregated dataframes
        aggregated_orderbooks = defaultdict(pd.DataFrame)

        # Running collection of limits
        limits = {}

        # Insert new column in each dataframe with their specific delivery period
        for period, books in self.orderbooks.items():
            for book in books.values():
                book.insert(0, "del_period", period)

        for period, books in tqdm(self.orderbooks.items(), desc="Aggregating Order Books", unit="orderbooks", leave=False):
            for ts, book in books.items():
                if len(book.contract_id.unique().tolist()) > 1:
                    # Prohibit dataframe containing local and XBID contracts
                    raise ValueError("Local and XBID contracts are not allowed to be mixed")

                current = aggregated_orderbooks[ts]

                if current.empty:
                    # Determine the time limit according to match_hours
                    min_period = book.del_period.min().split(" ")
                    from_time = datetime.strptime(f"{min_period[0]} {min_period[1]}", "%Y-%m-%d %H:%M") + timedelta(hours=self.match_hours)
                    to_time = datetime.strptime(min_period[-1], "%H:%M") + timedelta(hours=self.match_hours)
                    limits[ts] = f"{from_time} - {to_time.strftime('%H:%M')}"

                # Continue to save time if limit is reached
                if period > limits.get(ts, str(datetime.now())):
                    continue

                # Add orderbook to aggregation; DataFrame.append no longer exists in pandas >= 2.0, therefore pd.concat is used.
                # The first order book is taken over directly to avoid concatenating with an empty placeholder frame.
                aggregated_orderbooks[ts] = book if current.empty else pd.concat([current, book])

        return aggregated_orderbooks

    def match_orders(self,
                     orderbook: pd.DataFrame,
                     timestamp: str,
                     price: Union[int, float],
                     position: Union[int, float] = None,
                     side: str = "buy",
                     min_spread: Union[int, float] = 20,
                     order_execution: str = "NON") -> Decimal:
        """
        Wrapper for the battery order matching functions.

        Matches orders according to input parameters; adds trades made to trade_list and returns the remaining quantity.

        Depending on self.behaviour either _battery_order_matching_free or _battery_order_matching_linked is called. Only the free
        matching receives side, position and order_execution; only the linked matching receives min_spread and max_perc_per_trade.

        The order_execution parameter can be added to decide according to which logic the quantity should be filled. Allowed
        values are:

        NON - No restriction, partial & full execution are allowed

        FOK - Fill or Kill, if order isn't filled completely by first matching order, next matching order is loaded ->
        if none match next order book is loaded

        IOC - Immediate and Cancel, order is executed to maximum extent by first matching order, next order book is loaded ->
        allows price adjustments

        Args:
            orderbook (DataFrame): Single order book
            timestamp (str): Timestamp of order book
            price (int): Minimum/ Maximum Price for Transaction
            position (int, float): Quantity that can/ should be traded
            side (str): buy/sell
            min_spread (int/float): Minimum spread between a linked buy and sell transaction
            order_execution (str): Type of order execution that should be simulated

        Returns:
            Decimal: remaining quantity
        """
        if self.behaviour == "free":
            return _battery_order_matching_free(side, orderbook, timestamp, price, position, self.exec_orders_list, self.trade_list, self.positions,
                                                order_execution)

        return _battery_order_matching_linked(orderbook, timestamp, price, self.max_perc_per_trade, self.exec_orders_list, self.trade_list,
                                              self.positions, min_spread)

    def __run(self, dataframe):
        """
        Processes the aggregated order book of a single timestamp.

        Args:
            dataframe (pd.DataFrame): All orders sharing one timestamp; has to contain the column "ts"
        """
        # Overrideable function to set custom calculation logic that can be applied progressively over time
        self.adjust_params()

        # Extract timestamp
        timestamp = dataframe.ts.unique().tolist()[0]

        # Lock & unlock quantities
        self.positions.un_lock_by_timestamp(timestamp)

        # Overrideable function to set behaviour on order book level
        self.algorithm(timestamp, dataframe)

    def run(self) -> Union[dict, bool]:
        """
        Main function to execute custom trading logic. Handles all necessary steps and then calls the custom algorithm function.

        The order books are aggregated by timestamp (self.orderbooks is replaced by the aggregation) and processed in ascending
        timestamp order. Afterwards a summary is stored in self.results.

        Returns:
            dict or bool: True after the json trading report has been written if generate_report is set, otherwise the results as dict
        """
        # Aggregating single order books
        self.orderbooks = self._aggregate_orderbooks()

        # Flatten the aggregation into one dataframe; the timestamp keys end up in column "ts"
        df = pd.concat(self.orderbooks)
        df.reset_index(inplace=True)
        df.rename(columns={"level_0": 'ts'}, inplace=True)
        df.drop(columns=["level_1"], inplace=True)

        # Run custom function on aggregated order books. Iterating over the groups (instead of groupby.apply) always hands the
        # complete frame including the "ts" column to the callback and calls it exactly once per timestamp.
        for _, frame in df.groupby("ts"):
            self.__run(frame)

        # Create summary
        self.results["Battery Capacity (MW/h)"] = self.battery_capacity
        self.results["Current Capacity (MW/h)"] = self.positions.filled()
        self.results["Total Cash"] = round(sum(value["Cash"] for value in self.trade_list.values()), 2)
        self.results["Trades"] = self.trade_list

        if self.generate_report:
            return self._generate_trading_report()

        return self.results

    @abstractmethod
    def algorithm(self,
                  timestamp: str,
                  orderbook: pd.DataFrame) -> bool:
        """
        Overridable function that defines behaviour and calculations for a specific strategy. Please make sure to define the same input parameters
        as the original function.

        Since the order matching function already saves every important action it takes, it does not need any manual adjustments.

        Args:
            timestamp (str): Timestamp of order book
            orderbook (pd.DataFrame): Order book

        Returns:
            bool
        """
        pass


class PositionStorage:
    """
    Class to keep track and manage positions for a battery optimizing algorithm.

    Positions are stored per key (contract delivery period) in a defaultdict, so reading an unknown key creates it with
    a value of 0. Keys can be locked; locked keys can neither be charged nor discharged. When a key is unlocked its
    quantity is moved to the special key "0".
    """

    def __init__(self, max_capacity: Union[int, float]):
        """
        Args:
            max_capacity (int/float): Maximum capacity of the storage
        """
        self.max_cap = max_capacity
        self.positions = defaultdict(int)
        self.locked = []

    def __getitem__(self, key):
        return self.positions[key]

    def __iter__(self):
        return iter(self.positions)

    def __str__(self):
        return str(self.positions)

    def charge(self, timestamp: str, amount: Union[int, float]):
        """
        Adds amount to the position stored under timestamp.

        Args:
            timestamp (str): Key of the position
            amount (int/float): Quantity to add

        Raises:
            ValueError: if the key is locked or the capacity check fails
        """
        if self.is_locked(timestamp):
            raise ValueError("Requested position is locked!")
        elif not self.__theoretically_available(timestamp, amount, "charge"):
            raise ValueError("Position would exceed maximum Capacity!")

        self.positions[timestamp] = round(self.positions[timestamp] + amount, 3)

    def discharge(self, timestamp: str, amount: Union[int, float]):
        """
        Subtracts amount from the position stored under timestamp.

        Args:
            timestamp (str): Key of the position
            amount (int/float): Quantity to subtract

        Raises:
            ValueError: if the key is locked or the capacity check fails
        """
        if self.is_locked(timestamp):
            raise ValueError("Requested position is locked!")
        elif not self.__theoretically_available(timestamp, amount, "discharge"):
            raise ValueError("Battery is currently not loaded")

        self.positions[timestamp] = round(self.positions[timestamp] - amount, 3)

    def items(self):
        return iter(self.positions.items())

    def values(self):
        return iter(self.positions.values())

    def keys(self):
        return iter(self.positions)

    def is_locked(self, key) -> bool:
        return key in self.locked

    def un_lock_by_timestamp(self, timestamp: str):
        """
        Function locks and unlocks keys wherever necessary according to given timestamp

        Args:
            timestamp (str): Current timestamp
        """
        lock, unlock = self.__parse_timestamp_to_keys(timestamp)

        for k in lock:
            if k not in self.locked:
                self.locked.append(k)

        for k in unlock:
            if k in self.locked:
                self.__unlock(k)

    def available(self, key: str = None) -> float:
        """
        Returns amount available to buy until maximum capacity is reached.

        If a key is given and at least one position is negative, the maximum capacity is reduced only by the positions whose
        key sorts before the smallest key holding a negative position. The value of key itself is not used for the comparison.

        Args:
            key (str): Contract period to check against

        Returns:
            float
        """
        # All positives until first negative -> this might only be viable for linking strategies
        if key:
            negative_keys = [k for k, v in self.positions.items() if v < 0]
            if negative_keys:
                first_negative = min(negative_keys)
                return round(self.max_cap - sum(v for k, v in self.positions.items() if k < first_negative), 3)

        return round(self.max_cap - sum(self.positions.values()), 3)

    def filled(self, key: str = None) -> float:
        """
        Returns amount available to sell until current capacity is depleted.

        If key is given, function will return the sum of all positions earlier than the key.

        Args:
            key (str): Contract delivery period to look for

        Returns:
            float
        """
        if not key:
            return round(sum(self.positions.values()), 3)

        # Capacity that has been bought from a later contract cannot be sold on an earlier contract
        return round(sum(v for k, v in self.positions.items() if k < key), 3)

    def __unlock(self, key):
        # Remove from locked list
        self.locked = [i for i in self.locked if i != key]
        # Add to specific key '0' -> stored energy that is always available to sell
        self.positions["0"] = round(self.positions["0"] + self.positions[key], 3)
        # Reset key
        self.positions[key] = 0

    def __theoretically_available(self, key: str, new_value: int, direction: str):
        """
        Capacity check used by charge and discharge.

        Returns True if the maximum capacity is not exceeded by the sum of all positions when the position of key is
        replaced by new_value (charge) or -new_value (discharge). Any other direction yields None.

        NOTE(review): the check substitutes the current position of key instead of adding to it, while charge/discharge
        add/subtract the amount -- confirm that this is the intended invariant.
        """
        if direction == "charge":
            # can't charge a full battery or overshot max cap
            # (the lookup of self.positions[key] creates the key in the defaultdict if it did not exist yet)
            if not self.available() or self.positions[key] == self.max_cap:
                return False

            return self.max_cap - sum(v if not k == key else new_value for k, v in self.positions.items()) >= 0

        elif direction == "discharge":
            # Can't discharge an empty battery or overshot max cap
            if not self.filled() or self.positions[key] == -self.max_cap:
                return False

            return self.max_cap - sum(v if not k == key else -new_value for k, v in self.positions.items()) >= 0

        return None

    def __parse_timestamp_to_keys(self, timestamp: str) -> tuple[list, list]:
        """
        Derives the contract keys to lock and unlock for a timestamp.

        For each contract length (15, 30, 60 minutes) the timestamp is floored to that length. The period starting at the
        floored time has to be locked, the period ending at the floored time has to be unlocked.

        Args:
            timestamp (str): Current timestamp

        Returns:
            tuple[list, list]: keys to lock, keys to unlock; formatted as "YYYY-MM-DD HH:MM - HH:MM"
        """
        ts = pd.Timestamp(timestamp)
        freqs = [15, 30, 60]
        # "min" is the supported spelling of the minute frequency; the alias "T" is deprecated in recent pandas versions
        timestamps = [ts.floor(freq=f"{f}min") for f in freqs]

        to_be_locked = [f"{t.strftime('%Y-%m-%d %H:%M')} - {(t + pd.Timedelta(minutes=x)).strftime('%H:%M')}" for t, x in zip(timestamps, freqs)]
        to_be_unlocked = [f"{(t - pd.Timedelta(minutes=x)).strftime('%Y-%m-%d %H:%M')} - {t.strftime('%H:%M')}" for t, x in zip(timestamps, freqs)]

        return to_be_locked, to_be_unlocked
