import pandas as pd

def junction_summary_stats(junction_spanning_files, cds_files, output_csv_name):
    """Summarise junction coverage per enzyme and write the summary to a CSV file.

    Each file in ``junction_spanning_files`` is read with ``pd.read_csv`` and
    must contain the columns ``enzyme``, ``intron_id`` and ``parent``. Three
    counts are taken per enzyme in every file and then summed per enzyme over
    all files:

    * ``junction_spanning_peptides`` -- number of rows for the enzyme.
    * ``unique_junctions_covered`` -- number of distinct ``intron_id`` values.
    * ``total_junctions_covered`` -- number of distinct
      (``intron_id``, ``parent``) pairs, i.e. repeated hits of the same
      junction in the same parent are counted once.

    Rows without an enzyme are skipped. Rows without an ``intron_id`` still
    count as peptides but contribute to neither junction count.

    The extra column ``total_junction_coverage`` expresses
    ``total_junctions_covered`` as a percentage of the number of rows in
    ``cds_files`` whose ``intron_id`` is not the string ``'na'`` (duplicate
    ids are counted every time they occur). When that number is zero the
    percentage is whatever pandas produces for a division by zero
    (``inf`` or ``NaN``).

    Args:
        junction_spanning_files: Iterable of paths to junction-spanning
            peptide CSV files. May be empty, in which case the output holds
            only the header row.
        cds_files: Iterable of paths to CDS CSV files with an ``intron_id``
            column.
        output_csv_name: Path of the CSV file to write. The file is written
            exactly once, without the DataFrame index.

    Returns:
        None. The summary, sorted by enzyme, is written to ``output_csv_name``.
    """
    count_columns = ['junction_spanning_peptides',
                     'unique_junctions_covered',
                     'total_junctions_covered']

    # enzyme -> running [peptides, unique junctions, total junctions],
    # accumulated over every input file.
    totals = {}

    for input_file in junction_spanning_files:
        junctions_file = pd.read_csv(input_file)

        # groupby leaves out rows whose enzyme is missing.
        for enzyme, by_enzyme in junctions_file.groupby('enzyme'):
            # A missing intron_id cannot identify a junction.
            with_intron = by_enzyme.dropna(subset=['intron_id'])
            file_counts = (
                len(by_enzyme),
                with_intron['intron_id'].nunique(),
                # One hit per (junction, parent) pair avoids double counting.
                len(with_intron.drop_duplicates(subset=['intron_id', 'parent'])),
            )
            running = totals.setdefault(enzyme, [0, 0, 0])
            for position, count in enumerate(file_counts):
                running[position] += count

    output_df = pd.DataFrame(
        [[enzyme] + totals[enzyme] for enzyme in sorted(totals)],
        columns=['enzyme'] + count_columns,
    )

    # Real number of exon-exon junctions in the proteins being examined:
    # every CDS row that carries an intron id.
    intron_no = 0
    for input_name in cds_files:
        cdsdf = pd.read_csv(input_name)
        intron_no += int((cdsdf['intron_id'] != 'na').sum())

    # Percentage of the available junctions that are covered. Computed once,
    # after all CDS files are read, so the denominator is complete. The cast
    # keeps the division well defined when there are no enzymes at all.
    covered = output_df['total_junctions_covered'].astype(float)
    output_df['total_junction_coverage'] = covered / intron_no * 100
    output_df.to_csv(output_csv_name, index=False)

## EXAMPLE
# Filtered junction-spanning CSV files to summarise; add every file that
# should be included to this list.
junction_spanning_files = [
    'chr1_+_junction_spanning.csv',
    'chr1_-_junction_spanning.csv',
    'chr2_+_junction_spanning.csv',
    'chr2_-_junction_spanning.csv',
    'chr3_+_junction_spanning.csv',
    'chr3_-_junction_spanning.csv',
    'chr4_+_junction_spanning.csv',
    'chr4_-_junction_spanning.csv',
    'chr5_+_junction_spanning.csv',
    'chr5_-_junction_spanning.csv',
]

# CDS CSV files, listed in the same chromosome/strand order as above.
cds_files = [
    'chr1_+_cdsdf.csv',
    'chr1_-_cdsdf.csv',
    'chr2_+_cdsdf.csv',
    'chr2_-_cdsdf.csv',
    'chr3_+_cdsdf.csv',
    'chr3_-_cdsdf.csv',
    'chr4_+_cdsdf.csv',
    'chr4_-_cdsdf.csv',
    'chr5_+_cdsdf.csv',
    'chr5_-_cdsdf.csv',
]

# Uncomment to run the summary on the files listed above:
# junction_summary_stats(junction_spanning_files, cds_files, 'junction_statistics.csv')
