from dataclasses import dataclass
from auto_chart_patterns.chart_pattern import ChartPattern, ChartPatternProperties
from auto_chart_patterns.line import Line, Pivot, Point
from auto_chart_patterns.zigzag import window_peaks
from typing import List
import pandas as pd
import logging

# Module-level logger named after this module; used for debug tracing in resolve().
log = logging.getLogger(__name__)

@dataclass
class RsiDivergenceProperties(ChartPatternProperties):
    """Tunable thresholds used when scanning for RSI divergences."""
    # Period handed to calc_rsi() when computing the RSI series.
    rsi_period: int = 14
    # Maximum span between the two pivots of a pattern, counted in DataFrame
    # rows with both pivot rows included (not calendar days).
    max_periods_lapsed: int = 30
    # Minimum relative price change between the two pivots, expressed as a
    # fraction (0.03 == 3%); smaller moves are treated as "no change".
    min_price_change_pct: float = 0.03
    # Minimum relative RSI change between the two pivots, expressed as a
    # fraction (0.05 == 5%); smaller moves are treated as "no change".
    min_rsi_change_pct: float = 0.05

class RsiDivergencePattern(ChartPattern):
    """A pair of price pivots plus the RSI line drawn between the same two bars.

    After ``resolve()`` the ``pattern_type`` attribute is one of:

    * 0 - no divergence detected
    * 1 - Bullish (lower low in price, higher RSI)
    * 2 - Bearish (higher high in price, lower RSI)
    * 3 - Hidden Bullish (higher low in price, lower RSI)
    * 4 - Hidden Bearish (lower high in price, higher RSI)
    """

    def __init__(self, pivots: List[Pivot], divergence_line: Line):
        """Store the two price pivots and the RSI line joining the same bars."""
        self.pivots = pivots
        self.divergence_line = divergence_line
        self.extra_props = {}

    @classmethod
    def from_dict(cls, dict):
        """Rebuild a pattern from a mapping.

        Reads the keys ``pivots``, ``divergence_line``, ``pattern_type`` and
        ``pattern_name``. The parameter keeps its original name (``dict``) so
        that keyword callers continue to work.
        """
        self = cls(pivots=[Pivot.from_dict(p) for p in dict["pivots"]],
                   divergence_line=Line.from_dict(dict["divergence_line"]))
        self.pattern_type = dict["pattern_type"]
        self.pattern_name = dict["pattern_name"]
        return self

    def dict(self):
        """Return the base-class dictionary extended with ``divergence_line``."""
        obj = super().dict()
        obj["divergence_line"] = self.divergence_line.dict()
        return obj

    def get_pattern_name_by_id(self, id: int) -> str:
        """Map a pattern type id (1-4) to its display name.

        Raises:
            KeyError: if ``id`` is not one of the known pattern ids.
        """
        pattern_names = {
            1: "Bullish",
            2: "Bearish",
            3: "Hidden Bullish",
            4: "Hidden Bearish",
        }
        return pattern_names[id]

    def get_change_direction(self, value1: float, value2: float,
                             min_change_pct: float) -> int:
        """Classify the relative move from ``value1`` to ``value2``.

        Args:
            value1: Starting value (the base of the relative change).
            value2: Ending value.
            min_change_pct: Threshold as a fraction (e.g. 0.03 for 3%).

        Returns:
            1 if the relative change exceeds ``min_change_pct``, -1 if it is
            below ``-min_change_pct``, otherwise 0.
        """
        delta = value2 - value1
        if value1 == 0:
            # A relative change from zero is unbounded, so any non-zero move
            # exceeds the threshold. Handling it explicitly avoids a
            # ZeroDivisionError for Python floats and divide warnings for
            # NumPy scalars (an RSI of exactly 0 is a reachable value).
            if delta > 0:
                return 1
            if delta < 0:
                return -1
            return 0

        change_pct = delta / value1
        if change_pct > min_change_pct:
            return 1
        elif change_pct < -min_change_pct:
            return -1
        return 0

    def resolve(self, properties: RsiDivergenceProperties) -> 'RsiDivergencePattern':
        """Classify this candidate and set ``pattern_type`` / ``pattern_name``.

        ``pattern_type`` is reset to 0 and only changed when price and RSI move
        in opposite directions by more than their respective thresholds.
        ``pattern_name`` is only assigned when a divergence is found.

        Args:
            properties: Supplies ``min_price_change_pct`` and
                ``min_rsi_change_pct``.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            ValueError: if the pattern does not hold exactly two pivots.
        """
        if len(self.pivots) != 2:
            raise ValueError("Rsi Divergence must have 2 pivots")
        self.pattern_type = 0

        # Direction of the move between the two pivots, for price and for RSI,
        # each judged against its own minimum-change threshold.
        price_change_dir = self.get_change_direction(self.pivots[0].point.price,
            self.pivots[1].point.price, properties.min_price_change_pct)
        rsi_change_dir = self.get_change_direction(self.divergence_line.p1.price,
            self.divergence_line.p2.price, properties.min_rsi_change_pct)

        # Lazy %-style arguments: the message is only built when DEBUG is on.
        log.debug("points: %s, %s, rsi: %s, %s, "
                  "price_change_dir: %s, rsi_change_dir: %s",
                  self.pivots[0].point.time, self.pivots[1].point.time,
                  self.divergence_line.p1.price, self.divergence_line.p2.price,
                  price_change_dir, rsi_change_dir)

        if price_change_dir == 1 and rsi_change_dir == -1:
            if self.pivots[0].direction > 0:
                # higher high but lower RSI
                self.pattern_type = 2 # bearish
            else:
                # higher low but lower RSI
                self.pattern_type = 3 # hidden bullish
        elif price_change_dir == -1 and rsi_change_dir == 1:
            if self.pivots[0].direction > 0:
                # lower high but higher RSI
                self.pattern_type = 4 # hidden bearish
            else:
                # lower low but higher RSI
                self.pattern_type = 1 # bullish

        if self.pattern_type != 0:
            self.pattern_name = self.get_pattern_name_by_id(self.pattern_type)
        return self

def calc_rsi(prices: pd.DataFrame, period: int) -> pd.Series:
    """Compute the RSI of the ``close`` column.

    Gains and losses are smoothed with an exponentially weighted mean using
    ``alpha = 1 / period``; the first ``period`` rows of the result are NaN
    because ``min_periods=period`` and the first difference is NaN.

    Args:
        prices: DataFrame containing a ``close`` column.
        period: Smoothing period.

    Returns:
        Series of RSI values aligned with ``prices``.
    """
    close_delta = prices["close"].diff()
    smoothing = {
        "alpha": 1.0 / period,
        "min_periods": period,
        "adjust": True,
        "ignore_na": True,
    }

    # Average gain: negative moves are clipped to zero before smoothing.
    avg_gain = close_delta.clip(lower=0).ewm(**smoothing).mean()
    # Average loss: positive moves are clipped to zero, magnitudes smoothed.
    avg_loss = close_delta.clip(upper=0).abs().ewm(**smoothing).mean()

    relative_strength = avg_gain / avg_loss
    return 100.0 - (100.0 / (1.0 + relative_strength))

def handle_rsi_pivots(rsi_pivots: pd.DataFrame, is_high_pivots: bool,
                      properties: RsiDivergenceProperties, patterns: List[RsiDivergencePattern]):
    """Check each pair of consecutive RSI pivots for a divergence.

    Args:
        rsi_pivots: One row per pivot, with a ``row_number`` column plus
            ``rsi_high``/``high`` (high pivots) or ``rsi_low``/``low``
            (low pivots).
        is_high_pivots: True to read the high columns, False for the lows.
        properties: Supplies ``max_periods_lapsed`` and the thresholds used
            when the candidate is resolved.
        patterns: Output list; resolved patterns with a non-zero
            ``pattern_type`` are appended to it.
    """
    rsi_col, price_col = ('rsi_high', 'high') if is_high_pivots else ('rsi_low', 'low')
    # Every pivot built in one call shares the same direction.
    direction = 1 if is_high_pivots else -1

    for pos in range(1, len(rsi_pivots)):
        first = rsi_pivots.iloc[pos - 1]
        second = rsi_pivots.iloc[pos]
        first_bar = first['row_number'].astype(int)
        second_bar = second['row_number'].astype(int)

        # Skip pairs whose span (both pivot rows included) is too wide.
        bars_spanned = second_bar - first_bar + 1
        if bars_spanned > properties.max_periods_lapsed:
            continue

        # RSI line between the two pivots.
        divergence_line = Line(
            Point(first.name, first_bar, first[rsi_col]),
            Point(second.name, second_bar, second[rsi_col]))

        # Price pivots at the same two bars.
        price_pivots = [
            Pivot(Point(row.name, bar, row[price_col]), direction)
            for row, bar in ((first, first_bar), (second, second_bar))
        ]

        candidate = RsiDivergencePattern(price_pivots, divergence_line).resolve(properties)
        if candidate.pattern_type != 0:
            patterns.append(candidate)

def find_rsi_divergences(backcandles: int, forwardcandles: int,
                         properties: RsiDivergenceProperties,
                         patterns: List[RsiDivergencePattern],
                         df: pd.DataFrame):
    """
    Find RSI divergences between consecutive RSI peaks and troughs

    Note: a ``row_number`` column is written into ``df`` in place.

    Args:
        backcandles: Number of backcandles passed to window_peaks
        forwardcandles: Number of forwardcandles passed to window_peaks
        properties: RSI divergence properties
        patterns: List that receives the patterns found
        df: DataFrame with ``close``, ``high`` and ``low`` columns
    """
    rsi = calc_rsi(df, properties.rsi_period)

    # Keep the RSI only on rows where it equals the high / low series
    # returned by window_peaks; every other row becomes NaN.
    rsi_highs, rsi_lows = window_peaks(rsi, backcandles, forwardcandles)
    pivot_values = pd.DataFrame({
        'rsi_high': rsi.where(rsi == rsi_highs),
        'rsi_low': rsi.where(rsi == rsi_lows),
    })

    # Positional row index, used later to measure the distance between pivots.
    df['row_number'] = pd.RangeIndex(len(df))

    # Attach row numbers and high/low prices to the RSI pivot values.
    rsi_pivots = pd.merge(
        pivot_values,
        df[['row_number', 'high', 'low']],
        left_index=True,
        right_index=True,
        how='inner'
    )

    # High pivots first, then low pivots; dropna() leaves only pivot rows.
    passes = (
        (True, ['rsi_high', 'high', 'row_number']),
        (False, ['rsi_low', 'low', 'row_number']),
    )
    for is_high, columns in passes:
        handle_rsi_pivots(rsi_pivots[columns].dropna(), is_high, properties, patterns)
