import pandas as pd
import os
import numpy as np
from matplotlib import pyplot as plt
from datetime import date

# Lookup from a three-letter, capitalised amino-acid prefix to the lowercase
# amino-acid code. Xcal2iMS indexes this table with the first three characters
# of every Xcalibur sheet name (``AbrDict[Frag[:3]]``), so a sheet whose prefix
# is missing here raises KeyError.
# Presumably 'Iso' and 'Try' are the leading letters of 'Isoleucine' and
# 'Tryptophan' sheet names rather than standard abbreviations - confirm against
# an actual export.
AbrDict = {'Ala':'ala',
        'Arg':'arg',
        'Asp':'asp',
        'Cys':'cys',
        'Glu':'glu',
        'Gly':'gly',
        'His':'his3',  # NOTE(review): FragRef keys use the prefix 'his', not 'his3' - confirm this code is intended
        'Iso':'ile',
        'Leu':'leu',
        'Lys':'lys',
        'Met':'met',
        'Phe':'phe',
        'Pro':'pro',
        'Ser':'ser',
        'Thr':'thr',
        'Try':'trp',
        'Tyr':'tyr',
        'Val':'val'}
# Fragment -> backbone lookup, based on Table 1 in Nanchen et al., 2007.
# Keys are a lowercase amino-acid code followed by a three-digit number
# (presumably the fragment mass); values are strings such as '1:5' or '2'
# (presumably the range of carbon positions retained in the fragment - confirm
# against the publication).
# extract_AAcomplement and map_Fragment2Backbone read this table; two keys of
# the same amino acid with equal values are treated as sharing a backbone.
# NOTE(review): 'ala260':'1:2' and 'ala232':'1:3' do not follow the pattern of
# the other amino acids (e.g. 'val288':'1:5', 'val260':'2:5') - verify these
# two entries against the source table.
FragRef = {'ala260':'1:2','ala232':'1:3',
'gly288':'1:2','gly246':'1:2','gly218':'2','gly144':'2',
'val288':'1:5', 'val260':'2:5','val302':'1:2',
'leu344':'1:6', 'leu274':'2:6', 'leu200':'2:6',
'ile344':'1:6', 'ile274':'2:6', 'ile200':'2:6',
'pro328':'1:5', 'pro286':'1:5', 'pro258':'2:5', 'pro184':'2:5',
'met320':'1:5', 'met292':'2:5', 'met218':'2:5',
'ser432':'1:3', 'ser390':'1:3', 'ser362':'2:3', 'ser288':'2:3','ser302':'1:2',
'thr446':'1:4', 'thr404':'1:4', 'thr376':'2:4',
'phe336':'1:9', 'phe308':'2:9', 'phe234':'2:9','phe302':'1:2',
'asp460':'1:4', 'asp418':'1:4', 'asp390':'2:4', 'asp316':'2:4','asp302':'1:2',
'glu474':'1:4', 'glu432':'1:4', 'glu404':'2:4', 'glu330':'2:4','glu302':'1:2',
'lys431':'1:6', 'lys329':'2:6',
'his482':'1:6', 'his440':'1:6', 'his412':'2:6', 'his338':'2:6','his302':'1:2',
'tyr508':'1:9', 'tyr466':'1:9', 'tyr438':'2:9', 'tyr364':'2:9','tyr302':'1:2',
}


def Xcal2iMS(XcalFile, Experiment=None, AAIgnore=None):
    '''Convert Xcalibur data to iMS2Flux input format.

    Parameters
    ----------
    XcalFile : str
        Path to the Xcalibur Excel file.
        The workbook contains one sheet per fragment (e.g. 'Alanine_232'), thus
        typically hundreds of sheets. In every sheet the fifth row is used as
        header; column A is read as experiment identifier and column O as data
        value. The last sheet of the workbook is skipped.
    Experiment : list, optional
        Experiment identifiers to keep. The default is None, which keeps all
        experiments.
    AAIgnore : list, optional
        Amino acids to drop, given as the lowercase codes stored in AbrDict
        (e.g. 'ala'). The default is None, which drops nothing.

    Returns
    -------
    outputfile_name : str
        Path of the tab-separated iMS2Flux file that was written. It is placed
        next to XcalFile and named 'iMS2Flux_<input name>_<yymmdd>.tsv'; an
        existing file with that name is overwritten.

    Raises
    ------
    KeyError
        If the first three characters of a sheet name are not a key of AbrDict.
    '''
    Today = date.today().strftime('%y%m%d')
    # Extracting path and filename from XcalFile
    head, tail = os.path.split(XcalFile)

    # The context manager closes the workbook handle even if a sheet fails to parse.
    with pd.ExcelFile(XcalFile) as xls:
        # Reading the first sheet to get the experiment identifiers
        myExp = pd.read_excel(xls, header=4, usecols='A')
        myDat = pd.DataFrame()
        for sheet in xls.sheet_names[:-1]:
            # 'NF' entries are read as missing values
            myDat[sheet] = pd.read_excel(xls, sheet, header=4, usecols='O', na_values=['NF'])

    # Select defined experiments
    if Experiment is not None:
        myExpList = myExp.iloc[:, 0].isin(Experiment)
        myExp = myExp.loc[myExpList, :]
        myDat = myDat.loc[myExpList, :]

    # find columns that contain real data
    UseCols = myDat.notna().any()
    # find rows that contain real data (judged by the first fragment column only)
    UseRows = myDat.iloc[:, 0].notna()
    myDat = myDat.loc[:, UseCols]
    myDat = myDat.loc[UseRows, :]
    # get the fragment names
    myFrags_list = myDat.columns
    # iMS2Flux layout: one row per fragment, one column per experiment
    myDat = myDat.T
    myDat.reset_index(inplace=True, drop=True)
    # Delete rows that contain NaN in experiment identifiers
    myExp = myExp.loc[UseRows, :]
    myDat.columns = myExp.iloc[:, 0]

    # Create new dataframe with fragment names and mass, using AbrDict from above...
    AAId = [AbrDict[Frag[:3]] for Frag in myFrags_list]
    # the last three characters of the sheet name are taken as the fragment mass
    Mass = [Frag[-3:] for Frag in myFrags_list]
    mynew = pd.DataFrame([AAId, Mass]).T
    myRes = pd.concat([mynew, myDat], axis=1)

    # Remove defined amino acid fragments
    if AAIgnore is not None:
        myResList = ~myRes.iloc[:, 0].isin(AAIgnore)
        myRes = myRes.loc[myResList, :]
    myRes.reset_index(inplace=True, drop=True)

    # keep only the first name-occurrence of each fragment
    repeat = myRes[0].drop_duplicates(keep='first')
    myRes[0] = ''
    myRes.iloc[repeat.index, 0] = repeat
    # the two leading columns (amino acid, mass) carry empty header names
    myRes.rename(columns={0: '', 1: ''}, inplace=True)

    # Write iMS2Flux file
    # splitext strips the extension whatever its length ('.xls', '.xlsx', ...)
    stem = os.path.splitext(tail)[0]
    outputfile_name = os.path.join(head, 'iMS2Flux_' + stem + '_' + Today + '.tsv')
    # mode 'w' truncates a file left over from an earlier run on the same day
    with open(outputfile_name, 'w') as file:
        file.write('Response Values (Raw Data)')
        file.write('\n')
        myRes.to_csv(file, sep='\t', index=False)
    return outputfile_name

def extract_AAcomplement(Test, FragRef=FragRef):
    '''
    Return the other fragments of the same amino acid that map to the same
    backbone as ``Test``.

    The amino-acid prefix is made of all alphabetic characters of ``Test``.
    Every key of ``FragRef`` starting with that prefix is a candidate; ``Test``
    itself is excluded. The result lists the remaining candidates whose value
    equals ``FragRef[Test]``, in the order of ``FragRef``.

    Raises KeyError if ``Test`` is not among the candidate keys.
    '''
    # amino-acid prefix = fragment ID without its non-alphabetic characters
    letters = ''.join(filter(str.isalpha, Test))
    # all fragments listed for this amino acid
    siblings = {name: backbone for name, backbone in FragRef.items() if name.startswith(letters)}
    # the fragment under test must not be reported as its own complement
    del siblings[Test]
    wanted = FragRef[Test]
    return [name for name, backbone in siblings.items() if backbone == wanted]
    
def Fragments2Backbone(myClass):
    '''
    Function to map fragments to backbone. In the process AA fragments with identical backbone are grouped together.

    ``myClass`` is modified in place: every fragment listed in ``myClass.List``
    is removed and replaced by a new SubClass entry whose ID is the amino-acid
    letters of the fragment followed by the digits of its backbone string
    (as returned by map_Fragment2Backbone).

    If the first complement reported by extract_AAcomplement is also present in
    ``myClass.List``, both fragments are merged into one entry:
    the MDV is the element-wise mean of the two MDVs and the Std is the
    element-wise maximum of the two individual Std arrays and the standard
    deviation between the two MDVs.

    Raises KeyError if a fragment ID is not a key of FragRef.
    '''
    # Only the fragments present at entry are processed. Each pass takes the head of
    # the list, appends its replacement at the end and deletes the processed entries,
    # so after FragReplaceNumb removals only new IDs remain.
    FragReplaceNumb = len(myClass.List)
    Iterator = 0
    while Iterator < FragReplaceNumb:
        Frag = myClass.List[0]
        FragData = vars(myClass)[Frag]
        AAComplement = extract_AAcomplement(Frag, FragRef=FragRef)
        # generating new AA fragment ID
        AApure = ''.join(x for x in Frag if x.isalpha())
        CBone = ''.join(x for x in map_Fragment2Backbone(Frag) if x.isdigit())
        IDnew = AApure + CBone
        myFrag = SubClass()
        if AAComplement and AAComplement[0] in myClass.List:
            # a second fragment with identical backbone was measured: merge both
            Partner = vars(myClass)[AAComplement[0]]
            Pair = np.array([FragData.MDV, Partner.MDV])
            MDVnew = np.mean(Pair, axis=0)
            Spread = np.std(Pair, axis=0)
            # The merged uncertainty is the largest of the individual Std values and
            # the spread between the two measurements. Std entries that were never
            # set (None) are left out so that the spread alone is used.
            StdCandidates = [s for s in (FragData.Std, Partner.Std) if s is not None]
            StdCandidates.append(Spread)
            Stdnew = np.max(np.array(StdCandidates), axis=0)
            myClass.delete_Fragment(AAComplement[0])
            Iterator += 1  # accounting for the deleted complementary fragment
        else:
            MDVnew = FragData.MDV
            Stdnew = FragData.Std
        myFrag.add_DictInfo({'ID': IDnew, 'MDV': MDVnew, 'Std': Stdnew, 'CBackbone': FragData.CBackbone})
        myClass.add_Class(myFrag)
        myClass.delete_Fragment(Frag)
        Iterator += 1  # accounting for the deleted fragment

class SubClass:
    # Plain record describing one fragment.
    # Attributes created by the constructor (all None): MDV, Std, CBackbone, ID.
    # Further attributes may be attached through add_DictInfo.

    def __init__(self):
        # created in this fixed order: MDV, Std, CBackbone, ID
        for name in ('MDV', 'Std', 'CBackbone', 'ID'):
            vars(self)[name] = None

    def add_DictInfo(self, Dict):
        # Store every key/value pair of Dict as an attribute of this instance,
        # overwriting attributes that already exist.
        vars(self).update(Dict.items())

    def add_MDV(self, MDV):
        # Replace the MDV attribute.
        vars(self)['MDV'] = MDV

    def add_Std(self, Std):
        # Replace the Std attribute.
        vars(self)['Std'] = Std

    def add_CBackbone(self, CBackbone):
        # Replace the CBackbone attribute.
        vars(self)['CBackbone'] = CBackbone
class NestedClass:
    # Container for the fragments of one data set.
    #
    # Attributes
    #   List       : IDs of the fragments that were added, in insertion order
    #   Time       : time axis passed to the constructor or set_Time
    #   Conditions : condition names registered through set_Conditions
    # Every fragment added with add_Class is additionally stored as an instance
    # attribute named after its ID.

    def __init__(self, time=None):
        self.List = []
        self.Time = time
        self.Conditions = []

    def set_Conditions(self, conditions):
        # Create one empty NestedClass attribute per condition name and remember the names.
        for val in conditions:
            self.__dict__[val] = NestedClass()
        self.Conditions = conditions

    def set_Time(self, time):
        # Replace the time axis.
        self.Time = time

    def add_Class(self, Class):
        # Register a fragment object under its ``ID`` attribute.
        self.List.append(getattr(Class, 'ID'))
        self.__dict__[getattr(Class, 'ID')] = Class

    def add_Condition(self, Condition, Class):
        # Store a fragment object inside the nested container of ``Condition``.
        # NOTE(review): the ID is appended to this object's List while the fragment
        # itself is attached to the nested container (whose own List is not
        # updated) - confirm that this asymmetry is intended.
        self.List.append(getattr(Class, 'ID'))
        self.__dict__[Condition].__dict__[getattr(Class, 'ID')] = Class

    def delete_Fragment(self, fragment):
        # Remove a fragment attribute and its entry in List.
        # Raises KeyError / ValueError if the fragment is unknown.
        self.__dict__.pop(fragment)
        self.List.remove(fragment)

    def convert_Fragments2Backbone(self):
        # Replace all fragments by backbone-based entries (see Fragments2Backbone).
        Fragments2Backbone(self)

    def plot_MDVtime(self, fragment_list):
        # Plot MDV against Time with Std as error bars for every listed fragment.
        # Fragments that were never added are reported on stdout and skipped.
        for fragment in fragment_list:
            if fragment not in self.List:
                print(f"Class does not have a variable named: {fragment}")
                continue
            FragData = getattr(self, fragment)
            # the time axis lives in ``Time`` (capital T), as set by __init__ / set_Time
            plt.errorbar(self.Time, FragData.MDV.transpose(), FragData.Std.transpose(), label=fragment)
            plt.legend()
        plt.xlabel('Time (h)')
        plt.ylabel('Average Labeling')
        plt.show()

    def remove_TimeSample(self, time_index):
        # Drop the sample(s) at ``time_index`` from every fragment and from Time.
        # MDV and RSD are cut along axis 1, AvgMDV and AvgMDVrsd along axis 0;
        # each fragment must already carry these four attributes.
        for fragment in self.List:
            FragData = self.__dict__[fragment]
            FragData.MDV = np.delete(FragData.MDV, time_index, 1)
            FragData.AvgMDV = np.delete(FragData.AvgMDV, time_index, 0)
            FragData.RSD = np.delete(FragData.RSD, time_index, 1)
            FragData.AvgMDVrsd = np.delete(FragData.AvgMDVrsd, time_index, 0)
        self.Time = np.delete(self.Time, time_index, 0)
class MultiCondition():
    # Collection of per-condition containers sharing one time axis.
    #
    # Attributes
    #   List       : initialised empty; not modified by the methods of this class
    #   Conditions : condition names registered through set_Conditions
    #   Time       : time axis used for plotting
    # Every registered condition is additionally an instance attribute named after
    # the condition; it is None until add_Condition supplies a container.

    def __init__(self):
        self.List = []
        self.Conditions = []
        self.Time = None

    def set_Conditions(self, conditions):
        # Register the condition names; their containers start out as None.
        for val in conditions:
            self.__dict__[val] = None
        self.Conditions = conditions

    def set_Time(self, time):
        # Replace the time axis.
        self.Time = time

    def add_Condition(self, Condition, Class):
        # Attach the container ``Class`` to a registered condition.
        # Unregistered condition names are reported on stdout and ignored.
        if Condition in self.Conditions:
            self.__dict__[Condition] = Class
        else:
            print(f"Class does not have a variable named: {Condition}")

    def plot_MDVtime(self, condition_list, fragment_list):
        # Plot, for every requested condition and fragment, one error-bar line per
        # row of the fragment's MDV (labelled M0, M1, ...) with the matching row of
        # Std as error. Unknown or unpopulated conditions and unknown fragments are
        # reported on stdout and skipped.
        for condition in condition_list:
            if condition not in self.Conditions:
                print(f"Class does not have a variable named: {condition}")
                continue
            CondData = vars(self).get(condition)
            if CondData is None:
                # registered through set_Conditions but never filled by add_Condition
                print(f"No data has been added for condition: {condition}")
                continue
            for fragment in fragment_list:
                if fragment not in CondData.List:
                    print(f"Class does not have a variable named: {fragment}")
                    continue
                yval = getattr(CondData, fragment).MDV
                yerr = getattr(CondData, fragment).Std
                for row, (yrow, erow) in enumerate(zip(yval, yerr)):
                    plt.errorbar(self.Time, yrow, yerr=erow, label=f'{condition}-{fragment}-M{row}')
                plt.legend()
        plt.xlabel('Time (h)')
        plt.ylabel('Average Labeling')
        plt.show()

    def convert_Fragments2Backbone(self):
        # Run convert_Fragments2Backbone on the container of every registered
        # condition; conditions without a container are reported and skipped.
        for condition in self.Conditions:
            CondData = vars(self).get(condition)
            if CondData is None:
                print(f"No data has been added for condition: {condition}")
                continue
            CondData.convert_Fragments2Backbone()

def map_Fragment2Backbone(ID: str, FragRef: dict = FragRef):
    """
    Map Fragment ID to Backbone ID

    Parameters
    ----------
    ID : str
        Fragment ID, e.g. 'ala260'.
    FragRef : dict, optional
        Lookup table from fragment ID to backbone string. Defaults to the
        module-level FragRef (Table 1 from Nanchen et al., 2007).

    Returns
    -------
    The backbone entry stored for ``ID`` in ``FragRef``.

    Raises
    ------
    KeyError
        If ``ID`` is not a key of ``FragRef``.
    """
    # The backbone is a plain table lookup; the ID is not parsed, so an ID
    # without digits fails with KeyError like any other unknown ID.
    return FragRef[ID]
def find_IdentCbone(Backbone, Fragments):
    """
    Return, in input order, the entries of ``Fragments`` whose backbone
    (looked up with map_Fragment2Backbone) equals ``Backbone``.
    """
    return [name for name in Fragments if map_Fragment2Backbone(name) == Backbone]