from .__main__ import Strategy
from array import array
from .computation import dif, ema
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from datetime import datetime, timedelta
from pytz import timezone

class Momentum_settings:
    '''
    Thresholds and smoothing factors read by the Momentum strategy.

    Tuning hint:
        if the market is more likely to be in a "long stage", use a higher
        threshold for short and a lower threshold for long.
    '''
    # ---- long side -------------------------------------------------------
    # momentum must exceed long_buy_thres before a 'long' signal is emitted
    long_buy_thres = 0
    # NOTE: the sell thresholds are only referenced by the commented-out
    # legacy code at the bottom of this file
    long_sell_thres = 0

    # ---- short side ------------------------------------------------------
    # momentum must fall below -short_buy_thres before a 'short' signal is emitted
    short_buy_thres = 6
    short_sell_thres = 6

    # ---- smoothing -------------------------------------------------------
    # smooth factor of the ema used as the long-term indicator;
    # larger -> tracks the current price more slowly, but more consistently
    ema_smooth_factor = 5
    # smooth factor applied to the price before the momentum is computed
    dif_smooth_factor = 2
    
    
    

class Momentum(Strategy):
    '''
    ema enhanced momentum strategy

    Every row of the signal frame is labelled 'long', 'short' or 'empty':
        - 'long'  when the momentum ('dif') is above long_buy_thres
                  and the price is above its ema
        - 'short' when the momentum is below -short_buy_thres
                  and the price is below its ema
        - 'empty' otherwise (including rows whose momentum is NaN)
    '''
    def __init__(self, signal_df, momemtum_settings, verbose=False, only_morning=False):
        '''
        Purpose:
            input data, output signal (long? short? empty?) for the corresponding **time period**
                - a **time period** is represented by its start time point, datetime.datetime

        Args:
            signal_df: DataFrame containing 2 columns:
                - ['time']: datetime.datetime
                - ['price']: price wrt the time point
              NOTE: the frame is modified in place; the columns 'ema', 'dif'
              and 'signal' are added to it.

            momemtum_settings: object exposing long_buy_thres, short_buy_thres,
                ema_smooth_factor and dif_smooth_factor (e.g. Momentum_settings)

            verbose: when True, the momentum and the resulting signals are plotted

            only_morning: when True, 'long'/'short' are only emitted for rows
                whose time has hour < 10

        WARNING:
            - when backtest, price = the price of corresponding timepoint (use Open)
            - when real-time, price = the price of corresponding timepoint + 1 interval (use Close)
        '''
        super().__init__('momentum')
        self.only_morning = only_morning
        self.signal_df = signal_df
        self.momemtum_settings = momemtum_settings
        self.verbose = verbose

        ####### main process
        # create long term indicator
        self.signal_df['ema'] = self.long_term_momentum()

        # create short term momentum; there is no differential for the first
        # row, so it stays NaN. The column is built as one full array and
        # assigned in a single step: the chained form df['dif'][1:] = ...
        # writes to a temporary under pandas Copy-on-Write and is lost.
        differentials = self.compute_momentum(self.momemtum_settings.dif_smooth_factor)
        dif_column = np.full(self.signal_df.shape[0], np.nan)
        dif_column[1:] = np.asarray(differentials, dtype=float)
        self.signal_df['dif'] = dif_column

        # create signal
        self.signal_df['signal'] = self.signal_df.apply(self.create_signal_row, axis=1)

        if verbose:
            self.plot_strategy()

    def plot_strategy(self):
        '''
        Plot price and ema, with a green vertical line on every 'long' row
        and a red vertical line on every 'short' row.
        '''
        frame = self.signal_df
        # the vertical markers span the whole price range; compute it once
        price_low = np.min(frame['price'])
        price_high = np.max(frame['price'])

        _, ax = plt.subplots(figsize=(20, 10))
        ax.plot(frame['price'], label="signal price")
        ax.plot(frame['ema'], label=f'ema{self.momemtum_settings.ema_smooth_factor} of signal price')
        for side, colour in (('long', 'g'), ('short', 'r')):
            ax.vlines(x=np.array(frame[frame['signal'] == side].index),
                      ymin=price_low,
                      ymax=price_high,
                      color=colour, linestyle='-.')
        ax.legend()

    def convert_time(self, row, colname='Date'):
        '''
        Return row[colname] as a python datetime converted to the
        America/New_York timezone.
        '''
        nyc = timezone('America/New_York')
        return row[colname].to_pydatetime().astimezone(nyc)

    def create_signal_row(self, row):
        '''
        Compute the signal of a single row.

        Args:
            row: a row of signal_df holding 'price', 'dif' and 'ema'
                 ('time' is read as well when only_morning is set)

        Returns:
            'long', 'short' or 'empty'
        '''
        momentum = row['dif']
        # NaN never compares equal to anything (not even itself), so it has
        # to be detected with isna rather than with ==
        if pd.isna(momentum):
            return 'empty'

        price = row['price']
        trend = row['ema']
        settings = self.momemtum_settings

        if momentum > settings.long_buy_thres and price > trend:
            side = 'long'
        elif momentum < -settings.short_buy_thres and price < trend:
            side = 'short'
        else:
            return 'empty'

        # morning-only mode: drop every signal from hour 10 onwards
        if self.only_morning and row['time'].hour >= 10:
            return 'empty'
        return side

    def long_term_momentum(self):
        '''
        Long-term indicator: ema of the price, smoothed with
        momemtum_settings.ema_smooth_factor.

        (original author's note: use ema20 as default for minutes trade)
        '''
        return ema(self.signal_df['price'], self.momemtum_settings.ema_smooth_factor)

    def compute_momentum(self, dif_sf, focal_row_nums=None):
        '''
        1st differential of the smoothed price.

        Args:
            dif_sf: smooth factor passed to ema() before differencing
            focal_row_nums: optional row numbers to mark with a vertical
                line in the verbose plot (default: none)

        Returns:
            the differentials; one element shorter than signal_df because the
            first row has no predecessor

        - possible improvement:
            - 1st differential is related to the absolute index value
            - so divide by the index value  / self.prices[-1]
        '''
        ma_dif = dif(ema(self.signal_df['price'], dif_sf))
        assert len(ma_dif) == self.signal_df.shape[0] - 1

        if self.verbose:
            _, ax = plt.subplots(figsize=(20, 10))
            # zero line: above it the smoothed price rises, below it falls
            ax.hlines(xmin=0, xmax=ma_dif.shape[0], y=0, color='r', linestyle='-.')
            if focal_row_nums:
                dif_low = np.min(ma_dif)
                dif_high = np.max(ma_dif)
                for focal_row_num in focal_row_nums:
                    ax.vlines(ymin=dif_low, ymax=dif_high, x=focal_row_num, color='r', linestyle='-.')
            ax.plot(ma_dif, label="MA of delta price: denoised(smoothed) divergence")
            ax.legend()

        return ma_dif
        
        
        
    

        
        
        
### old
                
    
    # def create_signal(self):
        
        
    #     # no data for first differential for 1st data point
    #     emas = self.long_term_indicator
    #     prices = self.signal_df['Close'].values[1:]
    #     times = self.signal_df['pydate'].values[1:]
    #     emas = emas[1:]
        
    #     # create signal
    #     delta_diver = self.differentials
    #     long = []
    #     short = []
    #     empty = []
    #     for i, value in enumerate(delta_diver):
            
    #         if i==0: # when time =0, only consider when to buy
    #             if value > self.momemtum_settings.long_buy_thres: # buy long 
    #                 long.append(1)
    #                 short.append(0)
    #                 empty.append(0)
    #             elif value < - self.momemtum_settings.short_buy_thres: # buy short
    #                 long.append(0)
    #                 short.append(1)
    #                 empty.append(0)
    #             else:
    #                 long.append(0)
    #                 short.append(0)
    #                 empty.append(1)
    #         else: # when time >0, obtain prev status and consider current operation
    #             if long[i-1] == 1: # if prev hold long, consider when to sell long
    #                 if value < self.momemtum_settings.long_sell_thres: # if <= long_sell_thres, sell
    #                     if value < - self.momemtum_settings.short_buy_thres: # if buy short instead? < -thres
    #                         long.append(0)
    #                         short.append(1)
    #                         empty.append(0)
    #                     else:
    #                         long.append(0)
    #                         short.append(0)
    #                         empty.append(1)   
    #                 else: # continue to hold
    #                     long.append(1)
    #                     short.append(0)
    #                     empty.append(0)
                        
    #             elif short[i-1] == 1: # if prev hold short
    #                 if value > - self.momemtum_settings.short_sell_thres: # sell short
    #                     if value > self.momemtum_settings.long_buy_thres: # buy long instead?
    #                         long.append(1)
    #                         short.append(0)
    #                         empty.append(0)
    #                     else:
    #                         long.append(0)
    #                         short.append(0)
    #                         empty.append(1)
    #                 else: # continue to hold
    #                     long.append(0)
    #                     short.append(1)
    #                     empty.append(0)
                
    #             elif empty[i-1] == 1: # prev did not hold anything
    #                 if value > self.momemtum_settings.long_buy_thres:
    #                     long.append(1)
    #                     short.append(0)
    #                     empty.append(0)
    #                 elif value < - self.momemtum_settings.short_buy_thres:
    #                         long.append(0)
    #                         short.append(1)
    #                         empty.append(0)
    #                 else:
    #                     long.append(0)
    #                     short.append(0)
    #                     empty.append(1)

    #         if emas[i] > prices[i]: # go down trend
    #             if long[i] == 1:
    #                 long[i] = 0
    #                 empty[i]=1
    #         elif emas[i] <= prices[i]:
    #             if short[i] == 1:
    #                 short[i] = 0
    #                 empty[i]=1
                             
                
        
    #     assert len(prices) == len(emas) == len(long)
    #     for long_, short_, empty_ in zip(long, short, empty):
    #         assert ((long_ + short_ + empty_) == 1)

    #     signals = {'long':long, 'short':short, 'empty':empty, 'signal_time':times}
        
        
        
        # return signals # last signal is useless   