# Copyright 2018 QuantRocket LLC - All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run: python3 -m unittest discover -s tests/ -p test_*.py -t . -v

import os
import unittest
from unittest.mock import patch
import glob
import pickle
from pathlib import Path
import inspect
import pandas as pd
import numpy as np
from moonshot import Moonshot, MoonshotML
from moonshot.cache import TMP_DIR
from quantrocket.exceptions import ImproperlyConfigured
from sklearn.tree import DecisionTreeClassifier

class HistoricalPricesCacheTestCase(unittest.TestCase):

    def test_10_complain_if_houston_not_set(self):
        """
        Verifies that an unmocked backtest fails with a "HOUSTON_URL is not
        set" ImproperlyConfigured error.

        This is a control for later tests: once the cache has been emptied
        here, a backtest can only succeed without mocks if it was served
        from the cache.
        """
        # Empty the cache directory so leftover pickles from a previous run
        # cannot satisfy this backtest.
        for stale_pickle in glob.glob("{0}/moonshot_*.pkl".format(TMP_DIR)):
            os.remove(stale_pickle)

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """

            def prices_to_signals(self, prices):
                below_ten = prices.loc["Close"] < 10
                return below_ten.astype(int)

        with self.assertRaises(ImproperlyConfigured) as ctx:
            BuyBelow10().backtest()

        self.assertIn("HOUSTON_URL is not set", repr(ctx.exception))

    def test_20_load_history_from_mock(self):
        """
        Runs a strategy using mock to fill the history cache.

        get_prices and download_master_file are patched with canned data for
        two securities over four days, so the backtest succeeds without a
        HOUSTON_URL and the loaded prices end up in the history cache.
        """

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """

            def prices_to_signals(self, prices):
                signals = prices.loc["Close"] < 10
                return signals.astype(int)

        def mock_get_prices(*args, **kwargs):

            dt_idx = pd.DatetimeIndex(["2018-05-01","2018-05-02","2018-05-03", "2018-05-04"])
            fields = ["Close","Volume"]
            idx = pd.MultiIndex.from_product([fields, dt_idx], names=["Field", "Date"])

            prices = pd.DataFrame(
                {
                    "FI12345": [
                        # Close
                        9,
                        11,
                        10.50,
                        9.99,
                        # Volume
                        5000,
                        16000,
                        8800,
                        9900
                    ],
                    "FI23456": [
                        # Close
                        9.89,
                        11,
                        8.50,
                        10.50,
                        # Volume
                        15000,
                        14000,
                        28800,
                        17000

                    ],
                 },
                index=idx
            )

            return prices

        def mock_download_master_file(f, *args, **kwargs):
            # Write a minimal securities master (one row per Sid) as CSV into
            # the file-like object f, then rewind it so the caller can read it.

            master_fields = ["Timezone", "Symbol", "SecType", "Currency", "PriceMagnifier", "Multiplier"]
            securities = pd.DataFrame(
                {
                    "FI12345": [
                        "America/New_York",
                        "ABC",
                        "STK",
                        "USD",
                        None,
                        None
                    ],
                    "FI23456": [
                        "America/New_York",
                        "DEF",
                        "STK",
                        "USD",
                        None,
                        None,
                    ]
                },
                index=master_fields
            )
            securities.columns.name = "Sid"
            securities.T.to_csv(f, index=True, header=True)
            f.seek(0)

        with patch("moonshot.strategies.base.get_prices", new=mock_get_prices):
            with patch("moonshot.strategies.base.download_master_file", new=mock_download_master_file):
                results = BuyBelow10().backtest(end_date="2018-05-04")

        self.assertSetEqual(
            set(results.index.get_level_values("Field")),
            {'Commission',
             'AbsExposure',
             'Signal',
             'Return',
             'Slippage',
             'NetExposure',
             'TotalHoldings',
             'Turnover',
             'AbsWeight',
             'Weight'}
        )

        # Round away float noise, then replace NaN with the string "nan".
        # NaN != NaN, so without this, assertDictEqual could never match
        # cells that are legitimately empty.
        results = results.round(7)
        results = results.where(results.notnull(), "nan")

        expected_dates = [
            '2018-05-01T00:00:00',
            '2018-05-02T00:00:00',
            '2018-05-03T00:00:00',
            '2018-05-04T00:00:00']

        # Expected per-day values for each result field, keyed by sid.
        expected_by_field = {
            "Signal": {
                "FI12345": [1.0, 0.0, 0.0, 1.0],
                "FI23456": [1.0, 0.0, 1.0, 0.0]},
            "Weight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "AbsWeight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "NetExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "AbsExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "Turnover": {
                "FI12345": ["nan", 0.5, 0.5, 0.0],
                "FI23456": ["nan", 0.5, 0.5, 1.0]},
            "Commission": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Slippage": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Return": {
                "FI12345": [0.0, 0.0, -0.0227273, -0.0], # (10.50 - 11)/11 * 0.5
                "FI23456": [0.0, 0.0, -0.1136364, 0.0]}, # (8.50 - 11)/11 * 0.5
        }

        # Check each field in its own subTest so one mismatch doesn't hide
        # the others.
        for field, expected in expected_by_field.items():
            with self.subTest(field=field):
                actual = results.loc[field].reset_index()
                # Replace the column outright. Since pandas 2.0,
                # ``actual.loc[:, "Date"] = ...`` first tries to set values in
                # place, and a datetime64 column can parse ISO strings straight
                # back into Timestamps.
                actual["Date"] = actual.Date.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
                self.assertDictEqual(
                    actual.to_dict(orient="list"),
                    {"Date": expected_dates, **expected}
                )

    def test_30_load_history_from_cache(self):
        """
        Re-Runs the strategy without using mock to show that the history
        cache is used.

        No mocks are installed, so (per test_10) a real price fetch would
        fail with ImproperlyConfigured. Success therefore means the prices
        came from the cache, and the results must match the mocked run.
        """

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """

            def prices_to_signals(self, prices):
                signals = prices.loc["Close"] < 10
                return signals.astype(int)

        results = BuyBelow10().backtest(end_date="2018-05-04")

        self.assertSetEqual(
            set(results.index.get_level_values("Field")),
            {'Commission',
             'AbsExposure',
             'Signal',
             'Return',
             'Slippage',
             'NetExposure',
             'TotalHoldings',
             'Turnover',
             'AbsWeight',
             'Weight'}
        )

        # Round away float noise, then replace NaN with the string "nan".
        # NaN != NaN, so without this, assertDictEqual could never match
        # cells that are legitimately empty.
        results = results.round(7)
        results = results.where(results.notnull(), "nan")

        expected_dates = [
            '2018-05-01T00:00:00',
            '2018-05-02T00:00:00',
            '2018-05-03T00:00:00',
            '2018-05-04T00:00:00']

        # Expected per-day values for each result field, keyed by sid.
        expected_by_field = {
            "Signal": {
                "FI12345": [1.0, 0.0, 0.0, 1.0],
                "FI23456": [1.0, 0.0, 1.0, 0.0]},
            "Weight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "AbsWeight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "NetExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "AbsExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "Turnover": {
                "FI12345": ["nan", 0.5, 0.5, 0.0],
                "FI23456": ["nan", 0.5, 0.5, 1.0]},
            "Commission": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Slippage": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Return": {
                "FI12345": [0.0, 0.0, -0.0227273, -0.0], # (10.50 - 11)/11 * 0.5
                "FI23456": [0.0, 0.0, -0.1136364, 0.0]}, # (8.50 - 11)/11 * 0.5
        }

        # Check each field in its own subTest so one mismatch doesn't hide
        # the others.
        for field, expected in expected_by_field.items():
            with self.subTest(field=field):
                actual = results.loc[field].reset_index()
                # Replace the column outright. Since pandas 2.0,
                # ``actual.loc[:, "Date"] = ...`` first tries to set values in
                # place, and a datetime64 column can parse ISO strings straight
                # back into Timestamps.
                actual["Date"] = actual.Date.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
                self.assertDictEqual(
                    actual.to_dict(orient="list"),
                    {"Date": expected_dates, **expected}
                )

    def test_40_dont_use_cache_if_different_params(self):
        """
        Checks that changing the DB parameters bypasses the history cache.

        This strategy requests DB_FIELDS=["Open"], which differs from the
        earlier cached runs. It should therefore not be served from the
        cache, and without mocks the backtest is expected to raise
        ImproperlyConfigured ("HOUSTON_URL is not set").
        """

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """

            # Differs from the cached runs, so the cache should not be used.
            DB_FIELDS = ["Open"]

            def prices_to_signals(self, prices):
                is_cheap = prices.loc["Close"] < 10
                return is_cheap.astype(int)

        with self.assertRaises(ImproperlyConfigured) as ctx:
            BuyBelow10().backtest(end_date="2018-05-04")

        self.assertIn("HOUSTON_URL is not set", repr(ctx.exception))


    def test_50_dont_use_cache_if_no_cache(self):
        """
        Checks that no_cache=True bypasses the history cache.

        The strategy and DB parameters match the cached run, but passing
        no_cache=True means the cache should not be consulted. Without
        mocks, the backtest is expected to raise ImproperlyConfigured
        ("HOUSTON_URL is not set").
        """

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """

            def prices_to_signals(self, prices):
                under_ten = prices.loc["Close"] < 10
                return under_ten.astype(int)

        with self.assertRaises(ImproperlyConfigured) as raised:
            BuyBelow10().backtest(end_date="2018-05-04", no_cache=True)

        self.assertIn("HOUSTON_URL is not set", repr(raised.exception))

    def test_60_use_cache_if_end_date_and_db_modified(self):
        """
        Tests that if an end date is specified, the cache is used, even
        though we pretend that the db was modified after the file was cached.

        list_databases is patched to report one database as last modified
        60 seconds in the future. Prices are not mocked, so the backtest can
        only succeed, with results matching the mocked run, if it reads from
        the cache.
        """

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """

            def prices_to_signals(self, prices):
                signals = prices.loc["Close"] < 10
                return signals.astype(int)

        def mock_list_databases(**kwargs):
            return {
                "postgres": [],
                "sqlite": [{'last_modified': "2015-01-01T13:45:00",
                            'name': 'quantrocket.history.my-db1.sqlite',
                            'path': '/var/lib/quantrocket/quantrocket.history.my-db1.sqlite',
                            'size_in_mb': 3.1},
                           # Database was recently modified (in future)
                           {'last_modified': (pd.Timestamp.now() + pd.Timedelta(seconds=60)).isoformat(),
                            'name': 'quantrocket.history.my-db2.sqlite',
                            'path': '/var/lib/quantrocket/quantrocket.history.my-db2.sqlite',
                            'size_in_mb': 2.1},
                           ]}

        with patch("moonshot.cache.list_databases", new=mock_list_databases):
            results = BuyBelow10().backtest(end_date="2018-05-04")

        self.assertSetEqual(
            set(results.index.get_level_values("Field")),
            {'Commission',
             'AbsExposure',
             'Signal',
             'Return',
             'Slippage',
             'NetExposure',
             'TotalHoldings',
             'Turnover',
             'AbsWeight',
             'Weight'}
        )

        # Round away float noise, then replace NaN with the string "nan".
        # NaN != NaN, so without this, assertDictEqual could never match
        # cells that are legitimately empty.
        results = results.round(7)
        results = results.where(results.notnull(), "nan")

        expected_dates = [
            '2018-05-01T00:00:00',
            '2018-05-02T00:00:00',
            '2018-05-03T00:00:00',
            '2018-05-04T00:00:00']

        # Expected per-day values for each result field, keyed by sid.
        expected_by_field = {
            "Signal": {
                "FI12345": [1.0, 0.0, 0.0, 1.0],
                "FI23456": [1.0, 0.0, 1.0, 0.0]},
            "Weight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "AbsWeight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "NetExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "AbsExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "Turnover": {
                "FI12345": ["nan", 0.5, 0.5, 0.0],
                "FI23456": ["nan", 0.5, 0.5, 1.0]},
            "Commission": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Slippage": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Return": {
                "FI12345": [0.0, 0.0, -0.0227273, -0.0], # (10.50 - 11)/11 * 0.5
                "FI23456": [0.0, 0.0, -0.1136364, 0.0]}, # (8.50 - 11)/11 * 0.5
        }

        # Check each field in its own subTest so one mismatch doesn't hide
        # the others.
        for field, expected in expected_by_field.items():
            with self.subTest(field=field):
                actual = results.loc[field].reset_index()
                # Replace the column outright. Since pandas 2.0,
                # ``actual.loc[:, "Date"] = ...`` first tries to set values in
                # place, and a datetime64 column can parse ISO strings straight
                # back into Timestamps.
                actual["Date"] = actual.Date.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
                self.assertDictEqual(
                    actual.to_dict(orient="list"),
                    {"Date": expected_dates, **expected}
                )

    def test_70_load_history_from_mock(self):
        """
        Runs a strategy with no end date, using mock, to fill the history
        cache. This test is setup for later tests.

        get_prices and download_master_file are patched with the same canned
        data as test_20. Only Signal, Weight and Return are spot-checked here.
        """

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """

            def prices_to_signals(self, prices):
                signals = prices.loc["Close"] < 10
                return signals.astype(int)

        def mock_get_prices(*args, **kwargs):

            dt_idx = pd.DatetimeIndex(["2018-05-01","2018-05-02","2018-05-03", "2018-05-04"])
            fields = ["Close","Volume"]
            idx = pd.MultiIndex.from_product([fields, dt_idx], names=["Field", "Date"])

            prices = pd.DataFrame(
                {
                    "FI12345": [
                        # Close
                        9,
                        11,
                        10.50,
                        9.99,
                        # Volume
                        5000,
                        16000,
                        8800,
                        9900
                    ],
                    "FI23456": [
                        # Close
                        9.89,
                        11,
                        8.50,
                        10.50,
                        # Volume
                        15000,
                        14000,
                        28800,
                        17000

                    ],
                 },
                index=idx
            )

            return prices

        def mock_download_master_file(f, *args, **kwargs):
            # Write a minimal securities master (one row per Sid) as CSV into
            # the file-like object f, then rewind it so the caller can read it.

            master_fields = ["Timezone", "Symbol", "SecType", "Currency", "PriceMagnifier", "Multiplier"]
            securities = pd.DataFrame(
                {
                    "FI12345": [
                        "America/New_York",
                        "ABC",
                        "STK",
                        "USD",
                        None,
                        None
                    ],
                    "FI23456": [
                        "America/New_York",
                        "DEF",
                        "STK",
                        "USD",
                        None,
                        None,
                    ]
                },
                index=master_fields
            )
            securities.columns.name = "Sid"
            securities.T.to_csv(f, index=True, header=True)
            f.seek(0)

        with patch("moonshot.strategies.base.get_prices", new=mock_get_prices):
            with patch("moonshot.strategies.base.download_master_file", new=mock_download_master_file):
                results = BuyBelow10().backtest()

        self.assertSetEqual(
            set(results.index.get_level_values("Field")),
            {'Commission',
             'AbsExposure',
             'Signal',
             'Return',
             'Slippage',
             'NetExposure',
             'TotalHoldings',
             'Turnover',
             'AbsWeight',
             'Weight'}
        )

        # Round away float noise, then replace NaN with the string "nan".
        # NaN != NaN, so without this, assertDictEqual could never match
        # cells that are legitimately empty.
        results = results.round(7)
        results = results.where(results.notnull(), "nan")

        expected_dates = [
            '2018-05-01T00:00:00',
            '2018-05-02T00:00:00',
            '2018-05-03T00:00:00',
            '2018-05-04T00:00:00']

        # Expected per-day values for each spot-checked field, keyed by sid.
        expected_by_field = {
            "Signal": {
                "FI12345": [1.0, 0.0, 0.0, 1.0],
                "FI23456": [1.0, 0.0, 1.0, 0.0]},
            "Weight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "Return": {
                "FI12345": [0.0, 0.0, -0.0227273, -0.0], # (10.50 - 11)/11 * 0.5
                "FI23456": [0.0, 0.0, -0.1136364, 0.0]}, # (8.50 - 11)/11 * 0.5
        }

        # Check each field in its own subTest so one mismatch doesn't hide
        # the others.
        for field, expected in expected_by_field.items():
            with self.subTest(field=field):
                actual = results.loc[field].reset_index()
                # Replace the column outright. Since pandas 2.0,
                # ``actual.loc[:, "Date"] = ...`` first tries to set values in
                # place, and a datetime64 column can parse ISO strings straight
                # back into Timestamps.
                actual["Date"] = actual.Date.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
                self.assertDictEqual(
                    actual.to_dict(orient="list"),
                    {"Date": expected_dates, **expected}
                )

    def test_80_load_history_from_cache_if_no_end_date_and_db_not_modified(self):
        """
        Backtests without mocking prices to prove the history cache is used.

        Because no end_date is passed, list_databases is consulted to decide
        whether the cached prices are stale. The mocked response reports
        last-modified timestamps from 2015/2016, i.e. the databases have not
        been modified recently, so the cached prices are reused.
        """

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """

            def prices_to_signals(self, prices):
                closes = prices.loc["Close"]
                return (closes < 10).astype(int)

        def mock_list_databases(**kwargs):
            old_db1 = {'last_modified': "2015-01-01T13:45:00",
                       'name': 'quantrocket.history.my-db1.sqlite',
                       'path': '/var/lib/quantrocket/quantrocket.history.my-db1.sqlite',
                       'size_in_mb': 3.1}
            old_db2 = {'last_modified': "2016-01-01T04:04:00",
                       'name': 'quantrocket.history.my-db2.sqlite',
                       'path': '/var/lib/quantrocket/quantrocket.history.my-db2.sqlite',
                       'size_in_mb': 2.1}
            return {"postgres": [], "sqlite": [old_db1, old_db2]}

        with patch("moonshot.cache.list_databases", new=mock_list_databases):
            results = BuyBelow10().backtest()

        all_fields = {
            'Commission', 'AbsExposure', 'Signal', 'Return', 'Slippage',
            'NetExposure', 'TotalHoldings', 'Turnover', 'AbsWeight', 'Weight'}
        self.assertSetEqual(
            set(results.index.get_level_values("Field")), all_fields)

        # Round to 7 places and replace NaN with the string "nan" so the
        # frames can be compared with plain dict equality
        results = results.round(7)
        results = results.where(results.notnull(), "nan")

        expected_dates = [
            '2018-05-01T00:00:00',
            '2018-05-02T00:00:00',
            '2018-05-03T00:00:00',
            '2018-05-04T00:00:00']
        expected_by_field = {
            "Signal": {
                "FI12345": [1.0, 0.0, 0.0, 1.0],
                "FI23456": [1.0, 0.0, 1.0, 0.0]},
            "Weight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "Return": {
                "FI12345": [0.0, 0.0, -0.0227273, -0.0],  # (10.50 - 11)/11 * 0.5
                "FI23456": [0.0, 0.0, -0.1136364, 0.0]},  # (8.50 - 11)/11 * 0.5
        }

        # Check each field in turn: flatten the Date index into ISO strings
        # and compare the whole frame as a {column: values} dict
        for field, expected_sids in expected_by_field.items():
            frame = results.loc[field].reset_index()
            frame.loc[:, "Date"] = frame.Date.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
            self.assertDictEqual(
                frame.to_dict(orient="list"),
                dict(Date=expected_dates, **expected_sids))

    def test_90_dont_use_cache_if_no_end_date_and_db_modified(self):
        """
        Re-runs the strategy with the same DB parameters but with a (mocked)
        db last modified timestamp which is after the cached file timestamp,
        which should cause a cache miss and thus should trigger
        ImproperlyConfigured (the real price fetch fails because HOUSTON_URL
        is not set).

        Afterwards all moonshot cache pickles are removed. The removal runs in
        a ``finally`` clause so that a failing assertion cannot leave stale
        cache files behind to influence later test runs.
        """

        class BuyBelow10(Moonshot):
            """
            A basic test strategy that buys below 10.
            """
            def prices_to_signals(self, prices):
                signals = prices.loc["Close"] < 10
                return signals.astype(int)

        def mock_list_databases(**kwargs):
            return {
                "postgres": [],
                "sqlite": [{'last_modified': "2015-01-01T13:45:00",
                            'name': 'quantrocket.history.my-db1.sqlite',
                            'path': '/var/lib/quantrocket/quantrocket.history.my-db1.sqlite',
                            'size_in_mb': 3.1},
                           # Database "modified" 60 seconds in the future, so it
                           # is guaranteed to be newer than the cached file
                           {'last_modified': (pd.Timestamp.now() + pd.Timedelta(seconds=60)).isoformat(),
                            'name': 'quantrocket.history.my-db2.sqlite',
                            'path': '/var/lib/quantrocket/quantrocket.history.my-db2.sqlite',
                            'size_in_mb': 2.1},
                           ]}

        try:
            with patch("moonshot.cache.list_databases", new=mock_list_databases):
                with self.assertRaises(ImproperlyConfigured) as cm:
                    BuyBelow10().backtest()

            self.assertIn("HOUSTON_URL is not set", repr(cm.exception))
        finally:
            # Remove cached files whether or not the assertions above passed
            for cached_file in glob.glob("{0}/moonshot*.pkl".format(TMP_DIR)):
                os.remove(cached_file)

class MLFeaturesCacheTestCase(unittest.TestCase):

    def setUp(self):
        """
        Fits a tiny scikit-learn decision tree and pickles it to TMP_DIR.

        The training data maps [1, 1] -> 1 and [0, 0] -> 0, so the model's
        prediction simply echoes the (identical) feature values. The pickle
        path is stored on ``self.pickle_path`` for the strategies' MODEL.
        """
        training_X = np.array([[1, 1], [0, 0]])
        training_Y = np.array([1, 0])
        self.model = DecisionTreeClassifier()
        self.model.fit(training_X, training_Y)

        self.pickle_path = "{0}/decision_tree_model.pkl".format(TMP_DIR)
        with open(self.pickle_path, "wb") as model_file:
            pickle.dump(self.model, model_file)

    def tearDown(self):
        os.remove(self.pickle_path)

    def test_10_complain_if_houston_not_set(self):
        """
        Control test: with an empty cache and no mocks, the backtest must fail
        with "HOUSTON_URL is not set". This shows that the later tests which
        succeed without mocks really are served from the cache.
        """

        # Start from an empty cache by deleting any leftover pickles
        for stale_pickle in glob.glob("{0}/moonshot_*.pkl".format(TMP_DIR)):
            os.remove(stale_pickle)

        class DecisionTreeML(MoonshotML):

            MODEL = self.pickle_path

            def prices_to_features(self, prices):
                closes = prices.loc["Close"]
                features = {
                    "feature1": closes > 10,
                    "feature2": closes > 10,  # silly, duplicate feature
                }
                return features, None

            def predictions_to_signals(self, predictions, prices):
                # Go long when the model predicts a price below 10
                return (predictions == 0).astype(int)

        with self.assertRaises(ImproperlyConfigured) as cm:
            DecisionTreeML().backtest()

        self.assertIn("HOUSTON_URL is not set", repr(cm.exception))

    def test_20_cache_features(self):
        """
        Backtests with mocked prices and master file so that features are
        computed from scratch and written to the features cache.

        Checks every result field against hand-computed values, then asserts
        that exactly one features pickle now exists in TMP_DIR.
        """

        class DecisionTreeML(MoonshotML):

            MODEL = self.pickle_path

            def prices_to_features(self, prices):
                closes = prices.loc["Close"]
                features = {
                    "feature1": closes > 10,
                    "feature2": closes > 10,  # silly, duplicate feature
                }
                return features, None

            def predictions_to_signals(self, predictions, prices):
                # Go long when the model predicts a price below 10
                return (predictions == 0).astype(int)

        def mock_get_prices(*args, **kwargs):
            dates = pd.DatetimeIndex(["2018-05-01", "2018-05-02", "2018-05-03", "2018-05-04"])
            index = pd.MultiIndex.from_product(
                [["Close", "Volume"], dates], names=["Field", "Date"])
            # Each sid's column holds its 4 closes followed by its 4 volumes,
            # matching the (Field, Date) order of the index
            closes = {
                "FI12345": [9, 11, 10.50, 9.99],
                "FI23456": [9.89, 11, 8.50, 10.50]}
            volumes = {
                "FI12345": [5000, 16000, 8800, 9900],
                "FI23456": [15000, 14000, 28800, 17000]}
            prices = pd.DataFrame(
                {sid: closes[sid] + volumes[sid] for sid in closes},
                index=index)
            prices.columns.name = "Sid"
            return prices

        def mock_download_master_file(f, *args, **kwargs):
            master_fields = ["Timezone", "Symbol", "SecType", "Currency", "PriceMagnifier", "Multiplier"]
            securities = pd.DataFrame(
                {
                    "FI12345": ["America/New_York", "ABC", "STK", "USD", None, None],
                    "FI23456": ["America/New_York", "DEF", "STK", "USD", None, None],
                },
                index=master_fields)
            securities.columns.name = "Sid"
            # Write one row per sid as CSV into the caller's file object and
            # rewind it so the caller can read from the start
            securities.T.to_csv(f, index=True, header=True)
            f.seek(0)

        with patch("moonshot.strategies.base.get_prices", new=mock_get_prices), \
                patch("moonshot.strategies.base.download_master_file", new=mock_download_master_file):
            results = DecisionTreeML().backtest(end_date="2018-05-04")

        all_fields = {
            'Commission', 'AbsExposure', 'Signal', 'Return', 'Slippage',
            'NetExposure', 'TotalHoldings', 'Turnover', 'AbsWeight', 'Weight'}
        self.assertSetEqual(
            set(results.index.get_level_values("Field")), all_fields)

        # Round to 7 places and replace NaN with the string "nan" so the
        # frames can be compared with plain dict equality
        results = results.round(7)
        results = results.where(results.notnull(), "nan")

        expected_dates = [
            '2018-05-01T00:00:00',
            '2018-05-02T00:00:00',
            '2018-05-03T00:00:00',
            '2018-05-04T00:00:00']
        expected_by_field = {
            "Signal": {
                "FI12345": [1.0, 0.0, 0.0, 1.0],
                "FI23456": [1.0, 0.0, 1.0, 0.0]},
            "Weight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "AbsWeight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "NetExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "AbsExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "Turnover": {
                "FI12345": ["nan", 0.5, 0.5, 0.0],
                "FI23456": ["nan", 0.5, 0.5, 1.0]},
            "Commission": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Slippage": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Return": {
                "FI12345": [0.0, 0.0, -0.0227273, -0.0],  # (10.50 - 11)/11 * 0.5
                "FI23456": [0.0, 0.0, -0.1136364, 0.0]},  # (8.50 - 11)/11 * 0.5
        }

        # Check each field in turn: flatten the Date index into ISO strings
        # and compare the whole frame as a {column: values} dict
        for field, expected_sids in expected_by_field.items():
            frame = results.loc[field].reset_index()
            frame.loc[:, "Date"] = frame.Date.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
            self.assertDictEqual(
                frame.to_dict(orient="list"),
                dict(Date=expected_dates, **expected_sids))

        features_pickles = glob.glob("{0}/moonshot__features_*.pkl".format(TMP_DIR))
        self.assertEqual(len(features_pickles), 1)

    def test_30_load_features_from_cache(self):
        """
        Backtests again with no mocks at all to show the caches are used.

        prices_to_features raises ValueError if it is ever called, so a
        successful backtest shows the features were loaded from the features
        cache. Because get_prices is not mocked either, the prices must also
        have come from the history cache.
        """

        class DecisionTreeML(MoonshotML):

            MODEL = self.pickle_path

            def prices_to_features(self, prices):
                raise ValueError(
                    "in prices_to_features, but shouldn't have gotten here "
                    "because we should have loaded features from cache")

            def predictions_to_signals(self, predictions, prices):
                # Go long when the model predicts a price below 10
                return (predictions == 0).astype(int)

        results = DecisionTreeML().backtest(end_date="2018-05-04")

        all_fields = {
            'Commission', 'AbsExposure', 'Signal', 'Return', 'Slippage',
            'NetExposure', 'TotalHoldings', 'Turnover', 'AbsWeight', 'Weight'}
        self.assertSetEqual(
            set(results.index.get_level_values("Field")), all_fields)

        # Round to 7 places and replace NaN with the string "nan" so the
        # frames can be compared with plain dict equality
        results = results.round(7)
        results = results.where(results.notnull(), "nan")

        expected_dates = [
            '2018-05-01T00:00:00',
            '2018-05-02T00:00:00',
            '2018-05-03T00:00:00',
            '2018-05-04T00:00:00']
        expected_by_field = {
            "Signal": {
                "FI12345": [1.0, 0.0, 0.0, 1.0],
                "FI23456": [1.0, 0.0, 1.0, 0.0]},
            "Weight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "AbsWeight": {
                "FI12345": [0.5, 0.0, 0.0, 1.0],
                "FI23456": [0.5, 0.0, 1.0, 0.0]},
            "NetExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "AbsExposure": {
                "FI12345": ["nan", 0.5, 0.0, 0.0],
                "FI23456": ["nan", 0.5, 0.0, 1.0]},
            "Turnover": {
                "FI12345": ["nan", 0.5, 0.5, 0.0],
                "FI23456": ["nan", 0.5, 0.5, 1.0]},
            "Commission": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Slippage": {
                "FI12345": [0.0, 0.0, 0.0, 0.0],
                "FI23456": [0.0, 0.0, 0.0, 0.0]},
            "Return": {
                "FI12345": [0.0, 0.0, -0.0227273, -0.0],  # (10.50 - 11)/11 * 0.5
                "FI23456": [0.0, 0.0, -0.1136364, 0.0]},  # (8.50 - 11)/11 * 0.5
        }

        # Check each field in turn: flatten the Date index into ISO strings
        # and compare the whole frame as a {column: values} dict
        for field, expected_sids in expected_by_field.items():
            frame = results.loc[field].reset_index()
            frame.loc[:, "Date"] = frame.Date.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
            self.assertDictEqual(
                frame.to_dict(orient="list"),
                dict(Date=expected_dates, **expected_sids))

    def test_40_dont_use_cached_features_if_no_cache(self):
        """
        Backtests with no_cache=True: the cached features must be ignored, so
        prices_to_features is entered and raises our CustomError.
        """
        class CustomError(Exception):
            pass

        class DecisionTreeML(MoonshotML):

            MODEL = self.pickle_path

            def prices_to_features(self, prices):
                raise CustomError("this is a custom error")

            def predictions_to_signals(self, predictions, prices):
                # Go long when the model predicts a price below 10
                return (predictions == 0).astype(int)

        def mock_get_prices(*args, **kwargs):
            dates = pd.DatetimeIndex(["2018-05-01", "2018-05-02", "2018-05-03", "2018-05-04"])
            index = pd.MultiIndex.from_product(
                [["Close", "Volume"], dates], names=["Field", "Date"])
            # Each sid's column holds its 4 closes followed by its 4 volumes,
            # matching the (Field, Date) order of the index
            closes = {
                "FI12345": [9, 11, 10.50, 9.99],
                "FI23456": [9.89, 11, 8.50, 10.50]}
            volumes = {
                "FI12345": [5000, 16000, 8800, 9900],
                "FI23456": [15000, 14000, 28800, 17000]}
            prices = pd.DataFrame(
                {sid: closes[sid] + volumes[sid] for sid in closes},
                index=index)
            prices.columns.name = "Sid"
            return prices

        def mock_download_master_file(f, *args, **kwargs):
            master_fields = ["Timezone", "Symbol", "SecType", "Currency", "PriceMagnifier", "Multiplier"]
            securities = pd.DataFrame(
                {
                    "FI12345": ["America/New_York", "ABC", "STK", "USD", None, None],
                    "FI23456": ["America/New_York", "DEF", "STK", "USD", None, None],
                },
                index=master_fields)
            securities.columns.name = "Sid"
            # Write one row per sid as CSV into the caller's file object and
            # rewind it so the caller can read from the start
            securities.T.to_csv(f, index=True, header=True)
            f.seek(0)

        with patch("moonshot.strategies.base.get_prices", new=mock_get_prices), \
                patch("moonshot.strategies.base.download_master_file", new=mock_download_master_file):
            with self.assertRaises(CustomError) as cm:
                DecisionTreeML().backtest(end_date="2018-05-04", no_cache=True)

        self.assertIn("this is a custom error", repr(cm.exception))

    def test_50_dont_use_cached_features_if_prices_change(self):
        """
        Re-runs the strategy after modifying the historical prices pickle to
        have a different index, which should trigger a cache miss for the
        features, causing the strategy to enter prices_to_features and raise
        CustomError.

        The original prices pickle is restored in a ``finally`` clause so that
        a failure here cannot leave a corrupted history cache behind for the
        tests that run afterwards.
        """
        history_pickles = glob.glob("{0}/moonshot__history_*.pkl".format(TMP_DIR))
        self.assertEqual(len(history_pickles), 1, msg="expected only 1 history pickle in cache dir")
        history_pickle_filename = history_pickles[0]

        with open(history_pickle_filename, "rb") as f:
            orig_prices = pickle.load(f)

        class CustomError(Exception):
            pass

        class DecisionTreeML(MoonshotML):

            MODEL = self.pickle_path

            def prices_to_features(self, prices):
                raise CustomError("in prices_to_features, good that we got here")

            def predictions_to_signals(self, predictions, prices):
                # Go long when price is predicted to be below 10
                signals = predictions == 0
                return signals.astype(int)

        try:
            # Keep only the Close field and write it back to disk, so the
            # cached prices no longer match what the features were built from
            prices = orig_prices.loc[["Close"]]
            with open(history_pickle_filename, "wb") as f:
                pickle.dump(prices, f)

            with self.assertRaises(CustomError) as cm:
                DecisionTreeML().backtest(end_date="2018-05-04")

            self.assertIn("in prices_to_features", repr(cm.exception))
        finally:
            # Restore the original prices pickle even if an assertion failed
            with open(history_pickle_filename, "wb") as f:
                pickle.dump(orig_prices, f)

    def test_60_load_features_from_cache_again(self):
        """
        Another control test to make sure the cache is being used again.

        prices_to_features raises ValueError if called, so a successful
        backtest shows the features came from the cache. Returns are rounded
        to 7 decimal places before comparison, as in the other tests in this
        file, so the assertion does not depend on the last bits of
        floating-point arithmetic.
        """

        class DecisionTreeML(MoonshotML):

            MODEL = self.pickle_path

            def prices_to_features(self, prices):
                raise ValueError(
                    "in prices_to_features, but shouldn't have gotten here "
                    "because we should have loaded features from cache")

            def predictions_to_signals(self, predictions, prices):
                # Go long when price is predicted to be below 10
                signals = predictions == 0
                return signals.astype(int)

        results = DecisionTreeML().backtest(end_date="2018-05-04")

        self.assertSetEqual(
            set(results.index.get_level_values("Field")),
            {'Commission',
             'AbsExposure',
             'Signal',
             'Return',
             'Slippage',
             'NetExposure',
             'TotalHoldings',
             'Turnover',
             'AbsWeight',
             'Weight'}
        )

        # Round like the other tests so exact float representation (which
        # can vary across numpy/pandas versions) doesn't break the comparison
        returns = results.loc["Return"].round(7)
        returns = returns.reset_index()
        returns.loc[:, "Date"] = returns.Date.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
        self.assertDictEqual(
            returns.to_dict(orient="list"),
            {'Date': [
                '2018-05-01T00:00:00',
                '2018-05-02T00:00:00',
                '2018-05-03T00:00:00',
                '2018-05-04T00:00:00'],
             "FI12345": [0.0,
                     0.0,
                     -0.0227273, # (10.50 - 11)/11 * 0.5
                     -0.0],
             "FI23456": [0.0,
                     0.0,
                     -0.1136364, # (8.50 - 11)/11 * 0.5
                     0.0]}
        )

    def test_70_dont_use_cached_features_if_file_changes(self):
        """
        Re-runs the strategy after touching the present file, which should
        trigger a cache miss for the features, causing the strategy to enter
        prices_to_features and raise CustomError2.

        Afterwards all moonshot cache pickles are removed. The removal runs in
        a ``finally`` clause so that a failure here cannot leave stale cache
        files behind to influence later test runs.
        """
        try:
            # Touch this file to simulate modifying it
            thisfile = inspect.getfile(self.__class__)
            Path(thisfile).touch()

            class CustomError2(Exception):
                pass

            class DecisionTreeML(MoonshotML):

                MODEL = self.pickle_path

                def prices_to_features(self, prices):
                    raise CustomError2("in prices_to_features, good that we got here")

                def predictions_to_signals(self, predictions, prices):
                    # Go long when price is predicted to be below 10
                    signals = predictions == 0
                    return signals.astype(int)

            with self.assertRaises(CustomError2) as cm:
                DecisionTreeML().backtest(end_date="2018-05-04")

            self.assertIn("in prices_to_features", repr(cm.exception))
        finally:
            # Remove cached files whether or not the assertions above passed
            for cached_file in glob.glob("{0}/moonshot*.pkl".format(TMP_DIR)):
                os.remove(cached_file)
