"""
Toxopy (https://github.com/bchaselab/Toxopy)
© M. Alyetama, University of Nebraska at Omaha
Licensed under the terms of the MIT license
"""

from scipy.stats import mannwhitneyu
import pandas as pd
from toxopy import trials, nadlc, roi_behaviors
from itertools import combinations


def _stat_values(stat, p):
    """Format a Mann-Whitney U statistic and its p-value to three decimals."""
    return 'Statistics=%.3f, p=%.3f' % (stat, p)


def _alpha_test(p, alpha):
    """Return the H0 decision for *p*: H0 is rejected when ``p <= alpha``."""
    return 'fail to reject H0' if p > alpha else 'reject H0'


def _by_status(df, status, column):
    """Return the values of *column* for rows whose infection_status is *status*."""
    return list(df.loc[df['infection_status'] == status, column])


def _time_budget_mw(df, trls, alpha, only_sig):
    """
    << TIME BUDGET >>
    The time budget of a cat for a trial was determined by summing the times
    it spent on the individual behaviors assigned to each behavioral category
    (Affiliative, Calm, Exploration/Locomotion, and Fear).

    For every trial in *trls* and every category, Negative cats are compared
    with Positive cats on the ``value`` column. When *only_sig* is true, only
    comparisons that reject H0 are printed (the trial header always is).
    """
    behaviors = ['Exploration/locomotion', 'Fear', 'Calm', 'Affiliative']
    for t in trls:
        print(f'{"-" * 60}\n{t}')
        in_trial = df[df['trial'] == t]
        for b in behaviors:
            rows = in_trial[in_trial['Behavior'] == b]
            neg = _by_status(rows, 'Negative', 'value')
            pos = _by_status(rows, 'Positive', 'value')
            stat, p = mannwhitneyu(neg, pos)
            result = _alpha_test(p, alpha)
            # Use the same decision rule as every other test so the
            # "significant only" filter cannot disagree at p == alpha.
            if only_sig and result != 'reject H0':
                continue
            print(f'{b} ==> {_stat_values(stat, p)}, {result}')


def _latency_mw(df, alpha):
    """
    << The Latency Test >>
    Scores on a single behavior, "latency to exit the carrier", were used to
    compute results for the latency test (Negative vs Positive cats).
    """
    column = 't1_latency_to_exit_carrier'
    stat, p = mannwhitneyu(_by_status(df, 'Negative', column),
                           _by_status(df, 'Positive', column))
    print(f'Latency to exit the carrier ==> {_stat_values(stat, p)}, '
          f'{_alpha_test(p, alpha)}')


def _roi_mw(df, trls, alpha):
    """
    << Time Spent in Regions of Interest (ROIs) >>
    Video pixel coordinates for the DeepLabCut-generated labels were used to
    calculate the average time a cat spent near the walls (as opposed to being
    in the center) of the experimental room.

    Negative and Positive cats are compared per ROI ('walls', 'middle') and
    per trial on each ROI variable.
    """
    vois = ['cumulative_time_in_roi_sec', 'avg_time_in_roi_sec',
            'avg_vel_in_roi']
    for roi in ('walls', 'middle'):
        print(f'\n{roi}')
        in_roi = df[df['ROI_name'] == roi]
        for t in trls:
            print(f'\n{t}')
            in_trial = in_roi[in_roi['trial'] == t]
            neg = in_trial[in_trial['infection_status'] == 'Negative']
            pos = in_trial[in_trial['infection_status'] == 'Positive']
            for voi in vois:
                stat, p = mannwhitneyu(neg[voi], pos[voi])
                print(f'{voi} ==> {_stat_values(stat, p)}, '
                      f'{_alpha_test(p, alpha)}')


def _sniff_mw(df, alpha):
    """
    << Sniffing Treatment Vial >>
    A single behavior, "sniff treatment vial", was used in the "Sniffing
    Treatment Vial" test; Negative vs Positive cats are compared for each of
    the four treatment trials.
    """
    for column in ('t3_sniffsaline', 't5_sniffurine',
                   't7_sniffsaline', 't9_sniffurine'):
        stat, p = mannwhitneyu(_by_status(df, 'Negative', column),
                               _by_status(df, 'Positive', column))
        print(f'{column} ==> {_stat_values(stat, p)}, {_alpha_test(p, alpha)}')


def _compare_roi_within_mw(df, trls, behaviors, alpha, comparison,
                           trial_type, export_csv):
    """
    << Time Spent in ROIs - Within-group >>
    Similar to 'roi', except it compares the 'walls' ROI variables *within*
    each infection group between trials.

    comparison='all' compares every pair of trials drawn from
    ``trls[::2][1:]`` (trial_type='treatment') or ``trls[1::2]``
    (trial_type='CA'); comparison='split' compares trials ``trls[0:5]``
    against ``trls[5:10]`` once.

    Results are printed, or written to ``results.csv`` in the working
    directory when *export_csv* is true.

    Raises ValueError for an unknown *comparison*, or for comparison='all'
    without a valid *trial_type*.
    """
    if comparison == 'all':
        if trial_type == 'treatment':
            pool = trls[::2][1:]
        elif trial_type == 'CA':
            pool = trls[1::2]
        else:
            raise ValueError("comparison='all' requires trial_type "
                             f"'treatment' or 'CA', got {trial_type!r}")
        # (label, trials of group 1, trials of group 2)
        pairs = [(f'{a} vs {b}', [a], [b]) for a, b in combinations(pool, 2)]
    elif comparison == 'split':
        pairs = [('1st-half vs 2nd-half', trls[0:5], trls[5:10])]
    else:
        raise ValueError("comparison must be 'all' or 'split', "
                         f"got {comparison!r}")

    walls = df[df['ROI_name'] == 'walls']
    csv_lines = ['status,comparison,stat,p,interpretation']
    for status in ('Negative', 'Positive'):
        if not export_csv:
            print(f'{"-" * 65}\n<< {status} >>')
        group = walls[walls['infection_status'] == status]
        for behavior in behaviors:
            if not export_csv:
                print(f'\n#{behavior}')
            for label, first, second in pairs:
                c1 = group.loc[group['trial'].isin(first), behavior]
                c2 = group.loc[group['trial'].isin(second), behavior]
                stat, p = mannwhitneyu(c1, c2)
                result = _alpha_test(p, alpha)
                if export_csv:
                    csv_lines.append(
                        f'{status},{label},{stat},{round(p, 4)},{result}')
                else:
                    print(f'{label} ==> {stat} ==> {round(p, 4)} ==> {result}')

    if export_csv:
        # Written in one go under `with` so the file is always closed.
        with open('results.csv', 'w') as f:
            f.write('\n'.join(csv_lines) + '\n')


def MannWhitney_U(csv_file, test, drop_non_dlc=False, only_sig=False,
                  comparison=None, trial_type=None, export_csv=False,
                  alpha=0.05):
    """Run one family of Mann-Whitney U tests on a results CSV and print it.

    Parameters
    ----------
    csv_file : str or file-like
        Anything accepted by ``pandas.read_csv``.
    test : str
        One of ``'time_budget'``, ``'latency'``, ``'roi'``,
        ``'sniffing_vial'`` or ``'compare_roi_within'``.
    drop_non_dlc : bool
        If true, drop rows whose ``cat`` value appears in ``nadlc()``.
    only_sig : bool
        ``'time_budget'`` only: print just the comparisons that reject H0.
    comparison : {'all', 'split'}
        ``'compare_roi_within'`` only: ``'all'`` compares every pair of trials
        selected by *trial_type*; ``'split'`` compares trials ``[0:5]``
        against ``[5:10]``.
    trial_type : {'treatment', 'CA'}
        ``'compare_roi_within'`` with ``comparison='all'`` only: pair the
        trials ``trials()[::2][1:]`` ('treatment') or ``trials()[1::2]``
        ('CA').
    export_csv : bool
        ``'compare_roi_within'`` only: write results to ``results.csv`` in the
        working directory instead of printing them.
    alpha : float
        Significance level; H0 is rejected when ``p <= alpha``.

    Raises
    ------
    ValueError
        If *test* is unknown, or *comparison*/*trial_type* are invalid for
        ``'compare_roi_within'``.
    """
    df = pd.read_csv(csv_file)

    if drop_non_dlc:
        df = df[~df['cat'].isin(nadlc())]

    if test == 'time_budget':
        _time_budget_mw(df, trials(), alpha, only_sig)
    elif test == 'latency':
        _latency_mw(df, alpha)
    elif test == 'roi':
        _roi_mw(df, trials(), alpha)
    elif test == 'sniffing_vial':
        _sniff_mw(df, alpha)
    elif test == 'compare_roi_within':
        _compare_roi_within_mw(df, trials(), roi_behaviors(), alpha,
                               comparison, trial_type, export_csv)
    else:
        raise ValueError(
            "test must be one of 'time_budget', 'latency', 'roi', "
            f"'sniffing_vial', 'compare_roi_within'; got {test!r}")
