import re

from Bio import SeqIO

from pypgatk.cgenomes.models import SNP
from pypgatk.toolbox.general import ParameterConfiguration


class CancerGenomesService(ParameterConfiguration):
    """
    Service that turns cancer mutation files (COSMIC or cBioPortal) plus a CDS
    FASTA file into protein FASTA databases.

    The constants below are the keys used to look options up in the pipeline
    parameters and in the default parameters of the configuration file.
    """
    # Pipeline / default-parameter key holding the path of the mutation file.
    CONFIG_CANCER_GENOMES_MUTATION_FILE = 'mutation_file'
    # Pipeline / default-parameter key holding the path of the CDS FASTA file.
    CONFIG_COMPLETE_GENES_FILE = "all_cds_genes_file"
    # Pipeline-parameter key holding the path of the output FASTA file.
    CONFIG_OUTPUT_FILE = "output_file"
    # Root key passed to the parent constructor; also the top-level section of the
    # default parameters that contains the CONFIG_COSMIC_SERVER section.
    CONFIG_COSMIC_DATA = "cosmic_data"
    # Top-level section of the default parameters that contains CONFIG_FILTER_INFO.
    CONFIG_KEY_DATA = 'proteindb'
    # Sub-section (under CONFIG_KEY_DATA) holding the filter options listed below.
    CONFIG_FILTER_INFO = 'filter_info'
    # Name of the column used to filter / group mutations.
    FILTER_COLUMN = "filter_column"
    # Comma-separated list of accepted values of the filter column ("all" disables filtering).
    ACCEPTED_VALUES = "accepted_values"
    # Flag enabling one additional output file per value of the filter column.
    SPLIT_BY_FILTER_COLUMN = "split_by_filter_column"
    # Path of the clinical sample file used by the cBioPortal conversion.
    CLINICAL_SAMPLE_FILE = 'clinical_sample_file'
    # Sub-section (under CONFIG_COSMIC_DATA) holding the default file names.
    CONFIG_COSMIC_SERVER = 'cosmic_server'

    def __init__(self, config_file, pipeline_arguments):
        """
        Build the service and resolve every option it needs.

        Each option is taken from the pipeline arguments when present, otherwise
        from the default parameters of the configuration file, otherwise from a
        hard-coded fallback.

        :param config_file configuration file
        :param pipeline_arguments pipelines arguments
        """
        super(CancerGenomesService, self).__init__(self.CONFIG_COSMIC_DATA, config_file, pipeline_arguments)

        # Filter options live under proteindb/filter_info in the defaults.
        option = self.get_mutations_default_options
        self._filter_column = option(variable=self.FILTER_COLUMN, default_value='CANCER_TYPE')
        accepted_raw = option(variable=self.ACCEPTED_VALUES, default_value="all")
        self._accepted_values = self.get_multiple_options(accepted_raw)
        self._split_by_filter_column = option(variable=self.SPLIT_BY_FILTER_COLUMN, default_value=False)
        self._local_clinical_sample_file = option(variable=self.CLINICAL_SAMPLE_FILE, default_value='')

        def resolve_file(key, fallback):
            # File names live under cosmic_data/cosmic_server in the defaults.
            if key in self.get_pipeline_parameters():
                return self.get_pipeline_parameters()[key]
            defaults = self.get_default_parameters()
            if self.CONFIG_COSMIC_DATA not in defaults:
                return fallback
            cosmic_section = defaults[self.CONFIG_COSMIC_DATA]
            if self.CONFIG_COSMIC_SERVER not in cosmic_section:
                return fallback
            server_section = cosmic_section[self.CONFIG_COSMIC_SERVER]
            if key not in server_section:
                return fallback
            return server_section[key]

        self._local_mutation_file = resolve_file(self.CONFIG_CANCER_GENOMES_MUTATION_FILE,
                                                 'CosmicMutantExport.tsv.gz')
        self._local_complete_genes = resolve_file(self.CONFIG_COMPLETE_GENES_FILE,
                                                  'All_COSMIC_Genes.fasta.gz')

        # The output file can only be overridden from the pipeline arguments.
        pipeline_parameters = self.get_pipeline_parameters()
        if self.CONFIG_OUTPUT_FILE in pipeline_parameters:
            self._local_output_file = pipeline_parameters[self.CONFIG_OUTPUT_FILE]
        else:
            self._local_output_file = 'output_database.fasta'

    def get_mutations_default_options(self, variable: str, default_value):
        return_value = default_value
        if variable in self.get_pipeline_parameters():
            return_value = self.get_pipeline_parameters()[variable]
        elif self.CONFIG_KEY_DATA in self.get_default_parameters() \
                and self.CONFIG_FILTER_INFO in self.get_default_parameters()[self.CONFIG_KEY_DATA] \
                and variable in self.get_default_parameters()[
            self.CONFIG_KEY_DATA][self.CONFIG_FILTER_INFO]:
            return_value = self.get_default_parameters()[self.CONFIG_KEY_DATA][self.CONFIG_FILTER_INFO][variable]
        return return_value

    @staticmethod
    def get_multiple_options(options_str: str):
        """
        This method takes an String like option1, option2, ... and produce and array [option1, option2,... ]
        :param options_str:
        :return: Array
        """
        return list(map(lambda x: x.strip(), options_str.split(",")))

    @staticmethod
    def get_mut_pro_seq(snp, seq):
        nucleotide = ["A", "T", "C", "G"]
        mut_pro_seq = ""
        if "?" not in snp.dna_mut and snp.aa_mut != 'p.?':  # unambiguous DNA change known in CDS sequence
            positions = re.findall(r'\d+', snp.dna_mut)
            if ">" in snp.dna_mut and len(positions) == 1:  # Substitution
                tmplist = snp.dna_mut.split(">")
                ref_dna = re.sub("[^A-Z]+", "", tmplist[0])
                mut_dna = re.sub("[^A-Z]+", "", tmplist[1])
                index = int(positions[0]) - 1
                if ref_dna == str(seq[index]).upper() and mut_dna in nucleotide:  #
                    seq_mut = seq[:index] + mut_dna + seq[index + 1:]
                    mut_pro_seq = seq_mut.translate(to_stop=False)
            elif "ins" in snp.dna_mut:
                index = snp.dna_mut.index("ins")
                insert_dna = snp.dna_mut[index + 3:]
                if insert_dna.isalpha():
                    ins_index1 = int(positions[0])
                    seq_mut = seq[:ins_index1] + insert_dna + seq[ins_index1:]
                    mut_pro_seq = seq_mut.translate(to_stop=False)

            elif "del" in snp.dna_mut:
                if len(positions) == 2:
                    del_index1 = int(positions[0]) - 1
                    del_index2 = int(positions[1])
                    seq_mut = seq[:del_index1] + seq[del_index2:]
                    mut_pro_seq = seq_mut.translate(to_stop=False)
                elif len(positions) == 1:
                    del_index1 = int(positions[0]) - 1
                    seq_mut = seq[:del_index1] + seq[del_index1 + 1:]
                    mut_pro_seq = seq_mut.translate(to_stop=False)
        else:
            if "?" not in snp.aa_mut:  # unambiguous aa change known in protein sequence
                positions = re.findall(r'\d+', snp.aa_mut)
                protein_seq = str(seq.translate(to_stop=False))

                if "Missense" in snp.type:
                    mut_aa = snp.aa_mut[-1]
                    index = int(positions[0]) - 1
                    mut_pro_seq = protein_seq[:index] + mut_aa + protein_seq[index + 1:]
                elif "Nonsense" in snp.type:
                    index = int(positions[0]) - 1
                    mut_pro_seq = protein_seq[:index]
                elif "Insertion - In frame" in snp.type:
                    index = snp.aa_mut.index("ins")
                    insert_aa = snp.aa_mut[index + 3:]
                    if insert_aa.isalpha():
                        ins_index1 = int(positions[0])
                        mut_pro_seq = protein_seq[:ins_index1] + insert_aa + protein_seq[ins_index1:]
                elif "Deletion - In frame" in snp.type:
                    if len(positions) == 2:
                        del_index1 = int(positions[0]) - 1
                        del_index2 = int(positions[1])
                        mut_pro_seq = protein_seq[:del_index1] + protein_seq[del_index2:]
                    elif len(positions) == 1:
                        del_index1 = int(positions[0]) - 1
                        mut_pro_seq = protein_seq[:del_index1] + protein_seq[del_index1 + 1:]
                elif "Complex" in snp.type and "frameshift" not in snp.type:
                    try:
                        index = snp.aa_mut.index(">")
                    except ValueError:
                        return ''
                    mut_aa = snp.aa_mut[index + 1:]
                    if "deletion" in snp.type:
                        del_index1 = int(positions[0]) - 1
                        del_index2 = int(positions[1])
                        mut_pro_seq = protein_seq[:del_index1] + mut_aa + protein_seq[del_index2:]

                    elif "insertion" in snp.type:
                        ins_index1 = int(positions[0]) - 1
                        mut_pro_seq = protein_seq[:ins_index1] + mut_aa + protein_seq[ins_index1 + 1:]
                    elif "compound substitution" in snp.type:
                        if "*" not in mut_aa:
                            del_index1 = int(positions[0]) - 1
                            del_index2 = int(positions[1])
                            mut_pro_seq = protein_seq[:del_index1] + mut_aa + protein_seq[del_index2:]
                        else:
                            del_index1 = int(positions[0]) - 1
                            mut_pro_seq = protein_seq[:del_index1] + mut_aa.replace("*", "")

        return mut_pro_seq

    def cosmic_to_proteindb(self):
        """
        This function translates the mutation file + COSMIC genes into a protein Fasta database. The
        method writes into the file system the output Fasta.

        Rows are skipped when their mutation description contains "coding silent",
        when the filter column value is not accepted, when the gene has no CDS
        record, or when the mutation cannot be applied to any CDS of the gene.
        Each distinct header is written once to the main output file. When
        splitting by the filter column is enabled, one additional file per filter
        value is written.

        :raises ValueError: if a required column is missing from the header line
        :return:
        """
        self.get_logger().debug("Starting reading All cosmic genes")
        # A gene identifier may occur several times in the FASTA; keep every record.
        cosmic_cds_db = {}
        for record in SeqIO.parse(self._local_complete_genes, 'fasta'):
            cosmic_cds_db.setdefault(record.id, []).append(record)

        regex = re.compile('[^a-zA-Z]')
        written_headers = set()
        groups_mutations_dict = {}
        accept_all = self._accepted_values == ['all']

        # Context managers guarantee both files are closed even if a row fails.
        with open(self._local_mutation_file, encoding="latin-1") as cosmic_input:
            # Drop the line terminator, otherwise the last column name carries a
            # trailing newline and can never be found by name.
            columns = cosmic_input.readline().rstrip("\r\n").split("\t")
            gene_col = columns.index("Gene name")
            enst_col = columns.index("Accession Number")
            cds_col = columns.index("Mutation CDS")
            aa_col = columns.index("Mutation AA")
            muttype_col = columns.index("Mutation Description")
            filter_col = None
            if self._filter_column:
                filter_col = columns.index(self._filter_column)

            with open(self._local_output_file, 'w') as output:
                self.get_logger().debug("Reading input CosmicMutantExport.tsv ...")
                for line_counter, line in enumerate(cosmic_input, start=1):
                    if line_counter % 10000 == 0:
                        msg = "Number of lines finished -- '{}'".format(line_counter)
                        self.get_logger().debug(msg)
                    row = line.strip().split("\t")
                    # filter out mutations from unspecified groups
                    if filter_col is not None and not accept_all \
                            and row[filter_col] not in self._accepted_values:
                        continue

                    if "coding silent" in row[muttype_col]:
                        continue

                    snp = SNP(gene=row[gene_col], mrna=row[enst_col], dna_mut=row[cds_col], aa_mut=row[aa_col],
                              type=row[muttype_col])
                    header = "COSMIC:%s:%s:%s" % (snp.gene, snp.aa_mut, snp.type.replace(" ", ""))

                    gene_records = cosmic_cds_db.get(snp.gene)
                    if gene_records is None:  # geneID is not in All_COSMIC_Genes.fasta
                        continue

                    # Use the first CDS of the gene on which the mutation can be applied.
                    mut_pro_seq = None
                    for record in gene_records:
                        try:
                            mut_pro_seq = self.get_mut_pro_seq(snp, record.seq)
                        except IndexError:
                            # Coordinates do not fit this CDS; try the next one.
                            continue
                        if mut_pro_seq:
                            break

                    if not mut_pro_seq:
                        continue

                    entry = ">%s\n%s\n" % (header, mut_pro_seq)
                    if header not in written_headers:
                        output.write(entry)
                        written_headers.add(header)

                    if self._split_by_filter_column and filter_col is not None:
                        groups_mutations_dict.setdefault(row[filter_col], {})[header] = entry

        # NOTE(review): str.replace removes every '.fa' occurrence, so an output
        # named 'x.fasta' yields group files named 'xsta_<group>.fa'. Kept as is
        # because downstream consumers may rely on the current names.
        base_name = self._local_output_file.replace('.fa', '')
        for group_name, entries in groups_mutations_dict.items():
            # Only letters of the group name are kept in the file name.
            with open(base_name + '_' + regex.sub('', group_name) + '.fa', 'w') as fn:
                fn.writelines(entries.values())

        self.get_logger().debug("COSMIC contains in total {} non redundant mutations".format(len(written_headers)))

    @staticmethod
    def get_sample_headers(header_line, filter_coumn):
        try:
            filter_col = header_line.index(filter_coumn)
        except ValueError:
            print(filter_coumn, ' was not found in the header row:', header_line)
            return None, None
        try:
            sample_id_col = header_line.index('SAMPLE_ID')
        except ValueError:
            print('SAMPLE_ID was not found in the header row:', header_line)
            return None, None
        return filter_col, sample_id_col

    def get_value_per_sample(self, local_clinical_sample_file, filter_column):
        sample_value = {}
        if local_clinical_sample_file:
            with open(local_clinical_sample_file, 'r') as clin_fn:
                filter_column_col, sample_id_col = None, None
                for line in clin_fn.readlines():
                    if line.startswith('#'):
                        continue
                    sl = line.strip().split('\t')
                    # check for header and re-assign columns
                    if 'SAMPLE_ID' in sl and filter_column in sl:
                        filter_column_col, sample_id_col = self.get_sample_headers(sl, filter_column)
                    if filter_column_col is not None and sample_id_col is not None:
                        sample_value[sl[sample_id_col]] = sl[filter_column_col].strip().replace(' ', '_')
                    else:
                        print("No column was found for {}, {} in {}".format(filter_column, 'SAMPLE_ID',
                                                                            local_clinical_sample_file))
        return sample_value

    @staticmethod
    def get_mut_header_cols(header_cols, row):
        for col in header_cols.keys():
            header_cols[col] = row.index(col)

        return header_cols

    def cbioportal_to_proteindb(self):
        """
        Translate a cBioPortal mutation file into a protein Fasta database.

        cBioportal studies have a data_clinical_sample.txt file that shows the
        Primary Tumor Site per Sample Identifier. The sample ID in the clinical
        file matches the Tumor_Sample_Barcode column in the mutations file.

        The clinical file is only needed when values are filtered or the output is
        split by the filter column; the method returns without writing anything if
        it is needed but missing or yields no samples. Each supported mutation
        (SNP, DEL, INS with an accepted Variant_Classification) is applied to the
        transcript sequence and the translation is written when it is longer than
        six residues. Rows that cannot be parsed or do not match the transcript
        are reported on stdout and skipped.

        :return:
        """
        regex = re.compile('[^a-zA-Z]')
        sample_groups_dict = {}
        group_mutations_dict = {}
        seq_dic = {}

        for record in SeqIO.parse(self._local_complete_genes, "fasta"):
            # Index by accession without version; the first record seen wins.
            seq_dic.setdefault(record.id.split(".")[0], record.seq)

        header_cols = {"HGVSc": None, "Transcript_ID": None, "Variant_Classification": None,
                       "Variant_Type": None, "HGVSp_Short": None, 'Tumor_Sample_Barcode': None}
        nucleotide = ["A", "T", "C", "G"]
        mutclass = ["Frame_Shift_Del", "Frame_Shift_Ins", "In_Frame_Del", "In_Frame_Ins", "Missense_Mutation",
                    "Nonsense_Mutation"]

        accept_all = self._accepted_values == ['all']
        needs_groups = not accept_all or self._split_by_filter_column

        # check if sample id and clinical files are given, if not and not filter is required then exit
        if needs_groups:
            if not self._local_clinical_sample_file:
                print('No clinical sample file is given therefore no filter could be applied.')
                return
            sample_groups_dict = self.get_value_per_sample(self._local_clinical_sample_file, self._filter_column)
            print('sample_groups_dict', self._local_clinical_sample_file, self._filter_column)
            if sample_groups_dict == {}:
                return

        with open(self._local_mutation_file, "r") as mutfile, open(self._local_output_file, "w") as output:
            for i, line in enumerate(mutfile):
                row = line.strip().split("\t")
                # Comment lines start with '#', with or without text after it.
                if row[0].startswith('#'):
                    print("skipping line ({}): {}".format(i, row))
                    continue
                # check for header in the mutations file and get column indices
                if set(header_cols.keys()).issubset(set(row)):
                    header_cols = self.get_mut_header_cols(header_cols, row)

                # check if any is none in header_cols then continue
                if None in header_cols.values():
                    print("Incorrect header column is given")
                    continue

                # get filter value and check it
                group = None
                if needs_groups:
                    try:
                        group = sample_groups_dict[row[header_cols['Tumor_Sample_Barcode']]]
                    except KeyError:
                        print("No clinical info was found for sample {}. Skipping (line {}): {}".format(
                            row[header_cols['Tumor_Sample_Barcode']], i, line))
                        continue
                    except IndexError:
                        # Without a sample there is no group to filter or split on.
                        print("No sampleID was found in (line {}): {}".format(i, row))
                        continue
                if not accept_all and group not in self._accepted_values:
                    continue

                gene = row[0]
                try:
                    pos = row[header_cols["HGVSc"]]
                    enst = row[header_cols["Transcript_ID"]]

                    seq_mut = ""
                    aa_mut = row[header_cols["HGVSp_Short"]]

                    vartype = row[header_cols["Variant_Type"]]
                    varclass = row[header_cols["Variant_Classification"]]
                except IndexError:
                    print("Incorrect line ({}):".format(i), row)
                    continue
                if varclass not in mutclass:
                    continue

                try:
                    seq = seq_dic[enst]
                except KeyError:
                    print("No matching recored for gene ({}) from row {} in FASTA file:".format(enst, row))
                    continue

                # HGVSc may be prefixed with "<transcript>:".
                if ":" in pos:
                    cdna_pos = pos.split(":")[1]
                else:
                    cdna_pos = pos

                if vartype == "SNP":
                    try:
                        enst_pos = int(re.findall(r'\d+', cdna_pos)[0])
                        idx = pos.index(">")
                        ref_dna = pos[idx - 1]
                        mut_dna = pos[idx + 1]
                    except (IndexError, ValueError):
                        # No coordinate, no '>' or nothing after '>'.
                        print("Incorrect SNP format or record", i, pos, line)
                        continue

                    if mut_dna not in nucleotide:
                        print(mut_dna, "is not a nucleotide base", pos)
                        continue
                    try:
                        if ref_dna == seq[enst_pos - 1]:
                            seq_mut = seq[:enst_pos - 1] + mut_dna + seq[enst_pos:]
                        else:
                            print("incorrect substitution, unmatched nucleotide", pos, enst)
                    except IndexError:
                        print("incorrect substitution, out of index", pos)
                elif vartype == "DEL":
                    try:
                        enst_pos = int(re.findall(r'\d+', cdna_pos.split("_")[0])[0])
                        del_dna = pos.split("del")[1]
                    except IndexError:
                        print("incorrect del format or record", i, pos, line)
                        continue
                    if not del_dna:
                        # With no bases listed the comparison below is trivially
                        # true and would emit the unchanged sequence as a mutant.
                        print("deletion without explicit bases, cannot be verified", pos)
                        continue
                    if del_dna == seq[enst_pos - 1:enst_pos - 1 + len(del_dna)]:
                        seq_mut = seq[:enst_pos - 1] + seq[enst_pos - 1 + len(del_dna):]
                    else:
                        print("incorrect deletion, unmatched nucleotide", pos)

                elif vartype == "INS":
                    try:
                        enst_pos = int(re.findall(r'\d+', cdna_pos.split("_")[0])[0])
                    except IndexError:
                        print("incorrect ins/dup format or record", i, pos, line)
                        continue
                    if "ins" in pos:
                        ins_dna = pos.split("ins")[1]
                    elif "dup" in pos:
                        ins_dna = pos.split("dup")[1]
                        if len(ins_dna) > 1:
                            # A multi-base duplication is inserted after its last base.
                            try:
                                enst_pos = int(re.findall(r'\d+', cdna_pos.split("_")[1])[0])
                            except IndexError:
                                print("incorrect ins/dup format or record", i, pos, line)
                                continue
                    else:
                        print("unexpected insertion format")
                        continue
                    if not ins_dna:
                        # Nothing to insert would emit the unchanged sequence as a mutant.
                        print("insertion without explicit bases, cannot be applied", pos)
                        continue

                    seq_mut = seq[:enst_pos] + ins_dna + seq[enst_pos:]

                if seq_mut == "":
                    continue

                mut_pro_seq = seq_mut.translate(to_stop=False)
                if len(mut_pro_seq) > 6:
                    header = "cbiomut:%s:%s:%s:%s" % (enst, gene, aa_mut, varclass)
                    output.write(">%s\n%s\n" % (header, mut_pro_seq))

                    if self._split_by_filter_column:
                        group_mutations_dict.setdefault(group, {})[header] = mut_pro_seq

        # NOTE(review): str.replace removes every '.fa' occurrence, so an output
        # named 'x.fasta' yields group files named 'xsta_<group>.fa'. Kept as is
        # because downstream consumers may rely on the current names.
        base_name = self._local_output_file.replace('.fa', '')
        for group, entries in group_mutations_dict.items():
            # Only letters of the group name are kept in the file name.
            with open(base_name + '_' + regex.sub('', group) + '.fa', 'w') as fn:
                for header, protein in entries.items():
                    fn.write(">{}\n{}\n".format(header, protein))
