from ..test import ExperimentTest, Results
from . import Snapshot


class HasCorrectItems(ExperimentTest):
    """Passes if an ASHS resource is found and such resource contains the main
    expected items."""

    passing = 'BBRCDEV_E02443',
    failing = 'BBRCDEV_E00385',

    def run(self, experiment_id):
        """Check that each expected ASHS output pattern matches some file.

        The patterns are shell-style globs (`fnmatch`) tested against the
        `_urn` of every file in the experiment's `ASHS` resource. The result
        fails on the first pattern, in list order, that matches nothing.
        """
        from fnmatch import fnmatch

        patterns = ['*_left_corr_usegray_volumes.txt',
                    '*_right_heur_volumes.txt',
                    '*_left_heur_volumes.txt',
                    '*_icv.txt',
                    '*_right_corr_nogray_volumes.txt',
                    '*_right_lfseg_heur.nii.gz',
                    '*_left_lfseg_heur.nii.gz',
                    '*_right_raw_volumes.txt',
                    '*_right_lfseg_corr_usegray.nii.gz',
                    '*_right_corr_usegray_volumes.txt',
                    '*_left_lfseg_corr_nogray.nii.gz',
                    '*_left_corr_nogray_volumes.txt',
                    '*_left_lfseg_corr_usegray.nii.gz',
                    '*_left_raw_volumes.txt',
                    '*_right_lfseg_corr_nogray.nii.gz']

        resource = self.xnat_instance.select.experiment(experiment_id) \
            .resource('ASHS')
        urns = {item._urn for item in resource.files()}

        # First pattern with no matching URN, or None when all are present.
        unmatched = next((pattern for pattern in patterns
                          if not any(fnmatch(urn, pattern) for urn in urns)),
                         None)
        if unmatched is not None:
            return Results(False,
                           data=['ASHS %s matching item not found.' % unmatched])

        return Results(True, data=[])


class HasCorrectASHSVersion(ExperimentTest):
    """This test checks the version of ASHS used for processing the images. Passes
    if ASHS outputs were created using the expected version (`ASHS version 1.0.0
    Release date: 20170915`)."""

    passing = 'BBRCDEV_E02443',
    failing = 'BBRCDEV_E00385',

    def run(self, experiment_id):
        """Compare the ASHS version recorded in the processing log.

        The log is read from `LOGS/<experiment label>.log` in the `ASHS`
        resource. The first line starting with 'ASHS version ' must start
        with `expected_version`.

        Fails if the experiment metadata cannot be retrieved, if the log file
        is missing, or if the log contains no version line. It also fails if
        that line does not match the expected version.
        """
        expected_version = 'ASHS version 1.0.0 Release date: 20170915'

        data = self.xnat_instance.array.mrsessions(experiment_id=experiment_id,
                                                   columns=['label']
                                                   ).data
        # Guard the data[0] lookup below against an empty query result.
        if not data:
            return Results(False,
                           data=['Experiment %s not found.' % experiment_id])
        labels = ['label', 'project', 'xnat:mrsessiondata/subject_id']
        exp_label, project, subject_id = [data[0][e] for e in labels]

        res = self.xnat_instance.select.project(project).subject(subject_id)\
            .experiment(experiment_id).resource('ASHS')

        log = res.file('LOGS/%s.log' % exp_label)
        if not log.exists():
            return Results(False, data=['ASHS log file not found.'])

        log_data = self.xnat_instance.get(log.attributes()['URI']).text
        version_lines = [line for line in log_data.splitlines()
                         if line.startswith('ASHS version ')]

        # Without this guard, a log lacking any version line would raise
        # IndexError while building the failure message.
        if not version_lines:
            return Results(False, data=['ASHS version not found in log file.'])
        if not version_lines[0].startswith(expected_version):
            return Results(False, data=['%s' % version_lines[0]])

        return Results(True, data=[])


class ASHSSnapshot(ExperimentTest, Snapshot):
    """This test creates a snapshot of the results generated by ASHS.
    Passes if the snapshot is created successfully. Fails otherwise. Does not
    tell anything on the segmentation quality."""

    passing = 'BBRCDEV_E02443',
    failing = 'BBRCDEV_E00754',  # has no ASHS resource
    resource_name = 'ASHS'
    axes = 'x'
    rowsize = 5
    figsize = (19, 10)
    step = 1
    threshold = 0
    n_slices = {'x': 20}
    labels = None

    def run(self, experiment_id):
        """Build the snapshot via `Snapshot.snap` and return its path.

        When the environment variable SKIP_SNAPSHOTS_TESTS is the string
        'True', no snapshot is made. The result then passes only for the
        reference passing experiment. Any exception raised by `snap` yields a
        failing result.
        """
        import os

        if os.getenv('SKIP_SNAPSHOTS_TESTS') == 'True':
            is_reference = experiment_id == self.passing[0]
            return Results(is_reference,
                           data=['Skipping it. (SKIP_SNAPSHOTS_TESTS)'])

        try:
            snapshot_path = self.snap(experiment_id)
        except Exception:
            # Best-effort: any snapshot failure is reported, not raised.
            return Results(False, data=['Snapshot creation failed.'])
        else:
            return Results(True, data=[snapshot_path])

    def report(self):
        """Embed the snapshot as Markdown on success; else return raw data."""
        if not self.results.has_passed:
            return self.results.data
        return ['![snapshot](%s)' % self.results.data[0]]


class HasNormalSubfieldVolumes(ExperimentTest):
    """This test compares subfield volumes with predefined boundaries. Passes if
    the number of outliers is strictly lower than 4. Fails otherwise."""

    passing = 'BBRCDEV_E02443',
    failing = 'BBRCDEV_E02824',

    def run(self, experiment_id):
        """Count left/right subfield volumes outside the `boundaries` ranges.

        A volume is an outlier when it is below the lower bound or above the
        upper bound of its subfield. The same bounds apply to both sides.
        Subfields absent from the volumes table are not counted as outliers;
        they are logged as a warning.

        Returns a Results whose data is a one-element list holding the list
        of outlier descriptions.
        """
        import logging as log

        threshold = 4
        # [lower, upper] volume bounds per subfield.
        boundaries = {'ERC': [316.10866, 701.6651200000001],
                      'CA3': [23.980800000000002, 105.48840000000034],
                      'SUB': [324.01752, 616.3894799999999],
                      'DG': [504.71943, 1140.2571200000004],
                      'sulcus': [152.89474, 714.1092500000002],
                      'CA1': [921.48378, 1827.1949599999998],
                      'CA2': [5.421, 30.295],
                      'BA36': [1011.9748099999999, 2591.5785499999997],
                      'misc': [34.057199999999995, 377.566],
                      'BA35': [313.9794, 741.5731199999998],
                      'PHC': [647.80592, 1405.6312799999992]}

        resource = self.xnat_instance.select.experiment(experiment_id)\
            .resource('ASHS')
        if not resource.exists():
            log.error('ASHS resource not found for %s' % (experiment_id))
            return Results(False, data=['ASHS resource not found'])

        try:
            df = resource.volumes()
        except IndexError:
            return Results(False, data=['ASHS volumes files missing.'])

        outliers = []
        missing = []
        for side in ['left', 'right']:
            side_df = df.query('side == "%s"' % side)
            for region, (lower, upper) in boundaries.items():
                volumes = side_df.query('region =="%s"' % region)['volume']\
                    .tolist()
                if not volumes:
                    missing.append('%s_%s' % (region, side))
                    continue
                volume = volumes[0]
                if volume < lower or volume > upper:
                    outliers.append('%s_%s %s (%s)'
                                    % (region, side, volume,
                                       boundaries[region]))

        # Missing subfields do not affect the verdict (HasAllSubfields covers
        # that), but they are surfaced instead of being silently discarded.
        if missing:
            log.warning('ASHS subfield volumes missing for %s: %s'
                        % (experiment_id, ', '.join(missing)))

        return Results(len(outliers) < threshold, data=[outliers])


class HasAllSubfields(ExperimentTest):
    """Passes if all hippocampal subfields are found in the resulting segmentation.
    Fails otherwise."""

    passing = 'BBRCDEV_E02443',
    failing = 'BBRCDEV_E02823',

    def run(self, experiment_id):
        """Check the ASHS volumes table for TIV and every subfield per side.

        Fails if the `ASHS` resource or its volume files are missing. It also
        fails unless the only side-less row is `tiv`, or if any subfield in
        `regions` is absent from the left or the right side.
        """
        import logging as log

        regions = ['ERC', 'CA3', 'SUB', 'DG', 'sulcus', 'CA1',
                   'CA2', 'BA36', 'misc', 'BA35', 'PHC']

        r = self.xnat_instance.select.experiment(experiment_id).resource('ASHS')
        if not r.exists():
            log.error('ASHS resource not found for %s' % experiment_id)
            return Results(False, data=['No ASHS resource found.'])

        try:
            volumes = r.volumes()
        except IndexError:
            return Results(False, data=['ASHS volumes files missing.'])

        # Exactly one row without a left/right side is expected: the TIV.
        sideless = volumes.query('side != "left" & side != "right"')
        if sideless['region'].tolist() != ['tiv']:
            return Results(False, data=['TIV missing'])

        for side in ['left', 'right']:
            found = volumes.query('side == "%s"' % side)['region'].tolist()
            # Compare by membership, not by count: a count check passes when
            # a missing subfield is offset by a duplicate or an extra label.
            missing = set(regions).difference(found)
            if missing:
                return Results(False,
                               data=['Missing subfields (%s)'
                                     % ', '.join(sorted(missing))])

        return Results(True, data=[''])


class AreCAVolumesConsistent(ExperimentTest):
    """Checks that CA1, CA2 and CA3 (_Cornu Ammonis_ areas) display the
    expected order in volumes in the resulting segmentation.
    Passes if `CA1` > `CA3` > `CA2` and fails otherwise."""

    passing = 'BBRCDEV_E02443',
    failing = 'BBRCDEV_E02803',

    def run(self, experiment_id):
        """Verify the descending volume order CA1 > CA3 > CA2 on each side.

        CA rows are sorted by decreasing volume. Each hemisphere's region
        sequence is compared to the expected order. Every mismatching
        hemisphere is reported with its observed order.
        """
        expected_order = ['CA1', 'CA3', 'CA2']

        resource = self.xnat_instance.select.experiment(experiment_id)\
            .resource('ASHS')
        if not resource.exists():
            return Results(False, data=['No ASHS resource found.'])

        try:
            hippo_vols = resource.volumes()
        except IndexError:
            return Results(False, data=['ASHS volumes files missing.'])

        is_ca_region = hippo_vols['region'].isin(expected_order)
        ca_vols = hippo_vols[is_ca_region].sort_values(by=['volume'],
                                                      ascending=[False])

        mismatches = []
        for hemis in ('left', 'right'):
            observed = ca_vols[ca_vols['side'] == hemis]['region'].to_list()
            if observed == expected_order:
                continue
            mismatches.append('%s: %s' % (hemis, ' > '.join(observed)))

        if not mismatches:
            return Results(True, data=[])
        return Results(False, data=['Inconsistent CA volume sizes (%s)'
                                    % ",".join(mismatches)])


class HaveRawImagesValidIntensityRange(ExperimentTest):
    """ASHS pipeline casts data from both input images -`MPRAGE` and `TSE`- to
    'short' (16 bit) datatype. For input images with large intensity value ranges
    (after scaling linear transformations are applied) that might lead to
    erroneous or incomplete results. This test checks that the range of intensities
    of ASHS input images (i.e. files `mprage_raw.nii` and `tse_raw.nii`) do not
    exceed the maximum number of signed values represented with 16 bit (`2^15`).
    Passes if ASHS raw input images have a range of intensity values smaller
    than 32768. Fails otherwise."""

    passing = 'BBRCDEV_E02803',
    failing = 'BBRCDEV_E02802',

    def run(self, experiment_id):
        """Download each ASHS raw input image and measure its value range.

        The range is `max - min` over the finite voxel values, as returned by
        nibabel's `finite_range` on the image data object.

        Returns a Results whose data holds one `[fname, range, min, max]`
        string row per image. On a missing resource or file, it holds a
        single error message instead. The result fails if any range is
        >= 2**15.
        """
        import os
        import tempfile
        import nibabel as nib
        from nibabel.volumeutils import finite_range

        res = True
        data = []
        r = self.xnat_instance.select.experiment(experiment_id).resource('ASHS')
        if not r.exists():
            return Results(False,
                           data=['No ASHS resource found.'])

        fh, fp = tempfile.mkstemp(suffix='.nii.gz')
        os.close(fh)

        # The temporary download target is removed on every exit path,
        # including the early return for a missing file.
        try:
            for fname in ['mprage_raw.nii.gz', 'tse_raw.nii.gz']:
                f = r.file(fname)
                if not f.exists():
                    return Results(False,
                                   data=['ASHS file `{}` not found.'
                                         .format(fname)])
                f.get(fp)
                img = nib.load(fp).dataobj
                # Convert NumPy scalars to plain Python numbers. The
                # subtraction then cannot wrap around in a small integer dtype
                # (e.g. int16), and it also works for float-valued data, where
                # range() would raise TypeError.
                min_val, max_val = [v.item() if hasattr(v, 'item') else v
                                    for v in finite_range(img)]
                # Same value as len(range(min_val, max_val)) for integers.
                range_val = max(max_val - min_val, 0)
                data.append([fname, str(range_val), str(min_val),
                             str(max_val)])

                if range_val >= 2**15:
                    res = False
        finally:
            if os.path.exists(fp):
                os.remove(fp)

        return Results(res, data=data)

    def report(self):
        """Render per-image rows as one '; '-joined line.

        When `run` stored a single error message instead of per-image rows,
        that message is reported as-is.
        """
        report = []

        if len(self.results.data) > 1:
            data = []
            for fname, range_val, min_val, max_val in self.results.data:
                data.append('`{}`: {} ({} - {})'.format(fname, range_val,
                                                        min_val, max_val))
            report.append('; '.join(data))
        else:
            report.append(self.results.data[0])

        return report
