import pandas as pd

from genomics_data_index.api.query.GenomicsDataIndex import GenomicsDataIndex
from genomics_data_index.api.query.features.MutationFeaturesSummarizer import MutationFeaturesSummarizer
from genomics_data_index.storage.SampleSet import SampleSet
from genomics_data_index.storage.model.db import Sample
from genomics_data_index.test.integration import snippy_all_dataframes


def test_summary_all(loaded_database_genomic_data_store: GenomicsDataIndex):
    """Summarize mutations over every sample in the database and compare with the snippy tables.

    The expected frame concatenates the snippy mutation tables for SampleA, SampleB and SampleC
    and counts how often each mutation occurs. Total is hard-coded to 9. Presumably this is the
    number of samples in the loaded fixture, all of which make up ``present_set``; confirm
    against the fixture.
    """
    connection = loaded_database_genomic_data_store.connection
    sample_ids = {sample.id for sample in connection.database.get_session().query(Sample).all()}

    combined = pd.concat([pd.read_csv(snippy_all_dataframes[name], sep='\t')
                          for name in ('SampleA', 'SampleB', 'SampleC')])
    aggregations = {
        'Sequence': 'first',
        'Position': 'first',
        'Deletion': 'first',
        'Insertion': 'first',
        'Mutation': 'count',
    }
    counted = combined.groupby('Mutation').agg(aggregations)
    counted = counted.rename(columns={'Mutation': 'Count'}).sort_index()
    # Total/Percent are appended after Count so the column order matches the summarizer output.
    expected = counted.assign(Total=9)
    expected = expected.assign(Percent=lambda frame: 100 * (frame['Count'] / frame['Total']))

    summarizer = MutationFeaturesSummarizer(connection=connection, ignore_annotations=True)
    actual = summarizer.summary(SampleSet(sample_ids))
    actual['Percent'] = actual['Percent'].astype(int)  # Convert to int for easier comparison
    actual = actual.sort_index()

    assert len(expected) == len(actual)
    assert list(expected.columns) == list(actual.columns)
    assert list(expected.index) == list(actual.index)
    assert list(expected['Count']) == list(actual['Count'])
    assert list(expected['Total']) == list(actual['Total'])
    # The Percent value below is truncated to int (e.g. 2/9 * 100 -> 22).
    assert 22 == actual.loc['reference:619:G:C', 'Percent']


def _expected_mutation_summary(mutations: pd.DataFrame, total: int) -> pd.DataFrame:
    """Build the expected summary frame from raw snippy mutation rows.

    Rows are grouped by ``Mutation``. For each mutation the first Sequence/Position/Deletion/
    Insertion value is kept, and the number of rows is stored in ``Count``. ``Total`` is set
    to ``total``, and ``Percent`` is ``100 * Count / Total``.

    :param mutations: Rows read from the snippy test dataframes (possibly concatenated/filtered).
    :param total: The Total value to record for every mutation (number of samples in the group).
    :return: A frame indexed by mutation id and sorted by index.
    """
    summary = mutations.groupby('Mutation').agg({
        'Sequence': 'first',
        'Position': 'first',
        'Deletion': 'first',
        'Insertion': 'first',
        'Mutation': 'count',
    }).rename(columns={'Mutation': 'Count'}).sort_index()
    summary['Total'] = total
    summary['Percent'] = 100 * (summary['Count'] / summary['Total'])
    return summary


def _assert_summary_matches(expected_df: pd.DataFrame, mutations_df: pd.DataFrame) -> None:
    """Assert that a summarizer result has the same shape, columns, index, Count and Total as expected.

    Percent values are not compared here because the caller truncates them to int.
    """
    assert len(expected_df) == len(mutations_df)
    assert list(expected_df.columns) == list(mutations_df.columns)
    assert list(expected_df.index) == list(mutations_df.index)
    assert list(expected_df['Count']) == list(mutations_df['Count'])
    assert list(expected_df['Total']) == list(mutations_df['Total'])


def test_summary_unique(loaded_database_genomic_data_store: GenomicsDataIndex):
    """Check unique_summary() against expectations derived from the raw snippy tables.

    For each tested group of samples, the expected result is every mutation in the group's
    snippy tables that does not appear in the tables of the remaining snippy samples.
    """
    connection = loaded_database_genomic_data_store.connection
    session = connection.database.get_session()
    sampleA = session.query(Sample).filter(Sample.name == 'SampleA').one()
    sampleB = session.query(Sample).filter(Sample.name == 'SampleB').one()
    sampleC = session.query(Sample).filter(Sample.name == 'SampleC').one()
    all_sample_ids = {s.id for s in session.query(Sample).all()}

    mutations_summarizer = MutationFeaturesSummarizer(connection=connection, ignore_annotations=True)

    dfA = pd.read_csv(snippy_all_dataframes['SampleA'], sep='\t')
    dfB = pd.read_csv(snippy_all_dataframes['SampleB'], sep='\t')
    dfC = pd.read_csv(snippy_all_dataframes['SampleC'], sep='\t')

    def summarize_unique_to(present_ids: set) -> pd.DataFrame:
        """Run unique_summary() for present_ids against every other sample; Percent is truncated to int."""
        present_set = SampleSet(present_ids)
        other_set = SampleSet(all_sample_ids - present_ids)
        summary_df = mutations_summarizer.unique_summary(present_set, other_set=other_set).sort_index()
        summary_df['Percent'] = summary_df['Percent'].astype(int)  # Convert to int for easier comparison
        return summary_df

    # Unique to A: the expected rows must exclude any mutation that also occurs in B or C.
    # Otherwise the oracle only matches by coincidence.
    mutations_df = summarize_unique_to({sampleA.id})
    dfBC = pd.concat([dfB, dfC])
    expected_df = _expected_mutation_summary(dfA[~dfA['Mutation'].isin(dfBC['Mutation'])], total=1)

    _assert_summary_matches(expected_df, mutations_df)
    assert 46 == len(mutations_df)  # Check length against independently generated length
    assert 100 == mutations_df.loc['reference:3656:CATT:C', 'Percent']

    # Unique to B
    mutations_df = summarize_unique_to({sampleB.id})
    dfAC = pd.concat([dfA, dfC])
    expected_df = _expected_mutation_summary(dfB[~dfB['Mutation'].isin(dfAC['Mutation'])], total=1)

    _assert_summary_matches(expected_df, mutations_df)
    assert 100 == mutations_df.loc['reference:349:AAGT:A', 'Percent']

    # Unique to BC
    mutations_df = summarize_unique_to({sampleB.id, sampleC.id})
    expected_df = _expected_mutation_summary(dfBC[~dfBC['Mutation'].isin(dfA['Mutation'])], total=2)

    _assert_summary_matches(expected_df, mutations_df)
    assert 66 == len(mutations_df)  # Check length against independently generated length
    assert 100 == mutations_df.loc['reference:619:G:C', 'Percent']
    assert 50 == mutations_df.loc['reference:866:GCCAGATCC:G', 'Percent']
    assert 50 == mutations_df.loc['reference:349:AAGT:A', 'Percent']


def test_summary_annotations(loaded_database_genomic_data_store_annotations: GenomicsDataIndex):
    """Summarize three annotated samples and check columns, row count and two representative rows.

    With ``ignore_annotations=False`` the summary carries the annotation columns after
    Count/Total/Percent. Missing annotation values are normalized to 'NA' before comparison.
    """
    connection = loaded_database_genomic_data_store_annotations.connection
    db = connection.database

    mutations_summarizer = MutationFeaturesSummarizer(
        connection=connection,
        ignore_annotations=False)

    sample_names = ['SH14-001', 'SH14-014', 'SH10-014']
    three_samples = {db.get_session().query(Sample).filter(Sample.name == name).one().id
                     for name in sample_names}

    mutations_df = mutations_summarizer.summary(SampleSet(three_samples))

    expected_columns = ['Sequence', 'Position', 'Deletion', 'Insertion',
                        'Count', 'Total', 'Percent', 'Annotation', 'Annotation_Impact',
                        'Gene_Name', 'Gene_ID', 'Feature_Type', 'Transcript_BioType',
                        'HGVS.c', 'HGVS.p', 'ID_HGVS.c', 'ID_HGVS.p', 'ID_HGVS_GN.c', 'ID_HGVS_GN.p']
    assert expected_columns == list(mutations_df.columns)
    assert 177 == len(mutations_df)
    mutations_df['Percent'] = mutations_df['Percent'].astype(int)  # easier to compare percents in assert

    # missense variant (3/3)
    expected_missense = ['NC_011083', 140658, 'C', 'A', 3, 3, 100,
                         'missense_variant', 'MODERATE', 'murF', 'SEHA_RS01180', 'transcript', 'protein_coding',
                         'c.497C>A', 'p.Ala166Glu',
                         'hgvs:NC_011083:SEHA_RS01180:c.497C>A', 'hgvs:NC_011083:SEHA_RS01180:p.Ala166Glu',
                         'hgvs_gn:NC_011083:murF:c.497C>A', 'hgvs_gn:NC_011083:murF:p.Ala166Glu']
    assert expected_missense == list(mutations_df.loc['NC_011083:140658:C:A'])

    # Intergenic variant (1/3); protein-level fields are missing and become 'NA'
    expected_intergenic = ['NC_011083', 4555461, 'T', 'TC', 1, 3, 33,
                           'intergenic_region', 'MODIFIER', 'SEHA_RS22510-SEHA_RS26685', 'SEHA_RS22510-SEHA_RS26685',
                           'intergenic_region', 'NA',
                           'n.4555461_4555462insC', 'NA',
                           'hgvs:NC_011083:n.4555461_4555462insC', 'NA',
                           'hgvs_gn:NC_011083:n.4555461_4555462insC', 'NA']
    assert expected_intergenic == list(mutations_df.loc['NC_011083:4555461:T:TC'].fillna('NA'))
