from ..test import ExperimentTest, Results


class HasCorrectItems(ExperimentTest):
    """Passes if a `BAMOS` resource is found and this resource
    has the expected items according to the pipeline
    [specifications](https://gitlab.com/bbrc/xnat/docker-images/-/tree/master/bamos#outputs)."""

    passing = ('BBRCDEV_E02823',)
    failing = ('BBRCDEV_E02939',)

    def run(self, experiment_id):
        """Check that every expected BAMOS output pattern matches some file.

        Each pattern is looked up in the experiment's `BAMOS` resource. The
        returned `Results` passes only when no pattern is left unmatched; its
        data is the list of unmatched patterns (empty on success).
        """
        patterns = (
            'CorrectLesion_*',
            'Layers_*',
            'Lobes_*',
            'FLAIR_*',
            'Mask_*',
            'ScriptBaMoS_*',
            'Aff_FLAIRtoT1.txt',
            'TxtLesion*1Lesion_*',
            'Connect_*1Lesion_*',
            'LesionMahal_*',
            'GIF_Segmentation_*',
            'GIF_Parcellation_*',
        )

        experiment = self.xnat_instance.select.experiment(experiment_id)
        bamos = experiment.resource('BAMOS')

        # A pattern counts as missing when its file listing comes back empty.
        missing = [pattern for pattern in patterns
                   if not bamos.files(pattern).get()]

        return Results(not missing, data=missing)

    def report(self):
        """List the missing patterns (quotes rendered as backticks) on failure."""
        if self.results.has_passed:
            return []
        rendered = str(self.results.data).replace('\'', '`')
        return ['Missing items: {}.'.format(rendered)]


class FLAIRCoregistrationSnapshot(ExperimentTest):
    """This test generates a snapshot depicting the accuracy of the coregistration
    of the FLAIR image to the T1w image space, generated by the `BAMOS` pipeline.
    Snapshot consists of a T1w image as background (grayscale) and a coregistered
    FLAIR image as additional overlay (red colormap). Test passes if the snapshot
    is created successfully, fails otherwise. Does not tell anything on the
    quality of the registration procedure."""

    passing = 'BBRCDEV_E02823',
    failing = 'BBRCDEV_E02939',  # has no BAMOS resource

    def run(self, experiment_id):
        """Download the T1w and coregistered FLAIR images and snapshot them.

        The T1w scan is the one reported by the `HasUsableT1` test stored in
        the `BBRC_VALIDATOR` resource. The FLAIR image is taken from the
        `BAMOS` resource. Both are downloaded to temporary files, which are
        removed once `ants_snapshot` has returned.

        Returns a `Results` whose data is the output of `ants_snapshot` on
        success, or a one-item list with an error message on failure.
        """
        import os
        import tempfile
        from . import ants_snapshot

        resource_name = 'BAMOS'

        if os.getenv('SKIP_SNAPSHOTS_TESTS') == 'True':
            return Results(experiment_id == self.passing[0],
                           data=['Skipping it. (SKIP_SNAPSHOTS_TESTS)'])

        e = self.xnat_instance.select.experiment(experiment_id)
        r = e.resource(resource_name)

        if not r.exists():
            msg = 'Resource {} not found.'.format(resource_name)
            return Results(False, data=[msg])

        # Locate the coregistered FLAIR image among the BAMOS outputs before
        # creating any temporary file, so that an early return leaks nothing.
        flair_fname = 'FLAIR_{}*.nii.gz'.format(e.label())
        flair_files = list(r.files(flair_fname))
        if not flair_files:
            return Results(False, data=['{} not found, check `HasCorrectItems` '
                                        'test results.'.format(flair_fname)])

        # Locate the (usable) T1w scan image file.
        t1_scan = e.resource('BBRC_VALIDATOR').tests('ArchivingValidator',
                                                     key='HasUsableT1')['data'][0]
        t1_files = list(e.scan(t1_scan).resource('NIFTI').files('*.nii.gz'))
        if not t1_files:
            msg = 'NIFTI image for T1w scan {} not found.'.format(t1_scan)
            return Results(False, data=[msg])

        filepaths = []
        try:
            # T1w first (background), FLAIR second (overlay).
            downloads = [('T1_', t1_files[0]), ('FLAIR_', flair_files[0])]
            for prefix, xnat_file in downloads:
                fd, fp = tempfile.mkstemp(prefix=prefix, suffix='.nii.gz')
                os.close(fd)
                filepaths.append(fp)  # register before download for cleanup
                xnat_file.get(fp)

            # generate a snapshot
            res = ants_snapshot(filepaths[0], filepaths[1])
        finally:
            # The downloaded input images are not needed after this point.
            for fp in filepaths:
                if os.path.exists(fp):
                    os.remove(fp)

        return Results(True, res)

    def report(self):
        """Embed each snapshot path as an image on success, else the errors."""
        report = []
        if self.results.has_passed:
            for path in self.results.data:
                report.append('![snapshot]({})'.format(path))
        else:
            report = self.results.data

        return report


class LesionSegmentationSnapshot(ExperimentTest):
    """This generates a snapshot of the WMH lesion segmentation results produced
    by `BAMOS`. Snapshot consists of a FLAIR background image with the segmentation
    overlaid. Test passes if the snapshot is created successfully, fails otherwise.
    Does not tell anything on the segmentation quality."""

    passing = 'BBRCDEV_E02823',
    failing = 'BBRCDEV_E02939',
    # Filename patterns fetched from the BAMOS resource (`{label}` is replaced
    # by the experiment label). The first one is used as background image,
    # the remaining ones are overlaid on it.
    files = ['FLAIR_{label}*.nii.gz', 'CorrectLesion_{label}*.nii.gz']

    def run(self, experiment_id):
        """Download the FLAIR and lesion images and render a nisnap snapshot.

        Inputs are downloaded into a private temporary directory that is
        removed afterwards. On success the data holds the snapshot path. On
        any failure the data holds a generic message, and a partially created
        snapshot file is removed.
        """
        import os
        import shutil
        import tempfile
        from nisnap import snap

        resource_name = 'BAMOS'
        axes = 'xz'
        slices = {'x': list(range(144, 250, 3)),
                  'z': list(range(70, 171, 4))}
        rowsize = {'x': 10, 'z': 6}
        opacity = 95

        if os.getenv('SKIP_SNAPSHOTS_TESTS') == 'True':
            return Results(experiment_id == self.passing[0],
                           data=['Skipping it. (SKIP_SNAPSHOTS_TESTS)'])

        # Private download directory: avoids file-name collisions between
        # concurrent runs and lets us remove every input afterwards.
        download_dir = tempfile.mkdtemp(prefix='BAMOS_')
        snap_fp = None
        try:
            bg, filepaths = self.__download_bamos__(experiment_id,
                                                    download_dir,
                                                    resource_name)
            # Create snapshot via nisnap
            fd, snap_fp = tempfile.mkstemp(suffix=snap.__format__)
            os.close(fd)
            snap.plot_segment(filepaths, bg=bg, axes=axes, opacity=opacity,
                              slices=slices, rowsize=rowsize,
                              savefig=snap_fp, samebox=True)

        except Exception:
            # Do not leave an empty/partial snapshot file behind.
            if snap_fp is not None and os.path.exists(snap_fp):
                os.remove(snap_fp)
            return Results(False, data=['Snapshot creation failed.'])
        finally:
            shutil.rmtree(download_dir, ignore_errors=True)

        return Results(True, data=[snap_fp])

    def report(self):
        """Embed the snapshot as an image on success, else the error message."""
        report = []
        if self.results.has_passed:
            path = self.results.data[0]
            report.append('![snapshot](%s)' % path)
        else:
            report = self.results.data

        return report

    def __download_bamos__(self, experiment_id, destination, resource_name='BAMOS'):
        """Download the first file matching each pattern in `self.files`.

        Files are saved under `destination` using their XNAT file names.
        Returns `(background_path, other_paths)`, where the background is the
        file matched by the first pattern. Raises IndexError if a pattern has
        no match in the resource.
        """
        import os.path as op

        filepaths = []

        e = self.xnat_instance.select.experiment(experiment_id)
        r = e.resource(resource_name)
        label = e.label()

        for pattern in self.files:
            c = list(r.files(pattern.format(label=label)))[0]
            fp = op.join(destination, c.label())
            c.get(fp)
            filepaths.append(fp)

        bg = filepaths.pop(0)
        return bg, filepaths


class LobesSegmentationSnapshot(ExperimentTest):
    """This generates a snapshot of the brain lobes segmentation results produced
    by `BAMOS`. Snapshot consists of a FLAIR background image with the segmented
    brain lobes overlaid. Test passes if the snapshot is created successfully,
    fails otherwise. Does not tell anything on the segmentation quality."""

    passing = 'BBRCDEV_E02823',
    failing = 'BBRCDEV_E02939',
    # Filename patterns fetched from the BAMOS resource (`{label}` is replaced
    # by the experiment label). The first one is the background image and the
    # next one is the overlaid segmentation. Subclasses override this
    # attribute to snapshot a different BAMOS segmentation.
    files = ['FLAIR_{label}*.nii.gz', 'Lobes_{label}*.nii.gz']

    def run(self, experiment_id):
        """Download the FLAIR and segmentation images and render a snapshot.

        Inputs are downloaded into a private temporary directory that is
        removed afterwards. On success the data holds the snapshot path. On
        any failure the data holds a generic message, and a partially created
        snapshot file is removed.
        """
        import os
        import shutil
        import tempfile
        from nisnap import snap

        resource_name = 'BAMOS'
        axes = 'xz'
        slices = {'x': list(range(144, 250, 3)),
                  'z': list(range(70, 171, 3))}
        rowsize = {'x': 10, 'z': 7}
        opacity = 65

        if os.getenv('SKIP_SNAPSHOTS_TESTS') == 'True':
            return Results(experiment_id == self.passing[0],
                           data=['Skipping it. (SKIP_SNAPSHOTS_TESTS)'])

        # Private download directory: avoids file-name collisions between
        # concurrent runs and lets us remove every input afterwards.
        download_dir = tempfile.mkdtemp(prefix='BAMOS_')
        snap_fp = None
        try:
            bg, filepaths = self.__download_bamos__(experiment_id,
                                                    download_dir,
                                                    resource_name)
            # Create snapshot via nisnap
            fd, snap_fp = tempfile.mkstemp(suffix=snap.__format__)
            os.close(fd)
            snap.plot_segment(filepaths[0], bg=bg, axes=axes, opacity=opacity,
                              slices=slices, rowsize=rowsize,
                              savefig=snap_fp, samebox=True)

        except Exception:
            # Do not leave an empty/partial snapshot file behind.
            if snap_fp is not None and os.path.exists(snap_fp):
                os.remove(snap_fp)
            return Results(False, data=['Snapshot creation failed.'])
        finally:
            shutil.rmtree(download_dir, ignore_errors=True)

        return Results(True, data=[snap_fp])

    def report(self):
        """Embed the snapshot as an image on success, else the error message."""
        report = []
        if self.results.has_passed:
            path = self.results.data[0]
            report.append('![snapshot](%s)' % path)
        else:
            report = self.results.data

        return report

    def __download_bamos__(self, experiment_id, destination, resource_name='BAMOS'):
        """Download the first file matching each pattern in `self.files`.

        Files are saved under `destination` using their XNAT file names.
        Returns `(background_path, other_paths)`, where the background is the
        file matched by the first pattern. Raises IndexError if a pattern has
        no match in the resource.
        """
        import os.path as op

        filepaths = []

        e = self.xnat_instance.select.experiment(experiment_id)
        r = e.resource(resource_name)
        label = e.label()

        for pattern in self.files:
            c = list(r.files(pattern.format(label=label)))[0]
            fp = op.join(destination, c.label())
            c.get(fp)
            filepaths.append(fp)

        bg = filepaths.pop(0)
        return bg, filepaths


class LayersSegmentationSnapshot(LobesSegmentationSnapshot):
    # Same description as the lobes test, with the segmented structure renamed.
    __doc__ = LobesSegmentationSnapshot.__doc__.replace('lobes', 'layers')

    passing = ('BBRCDEV_E02823',)
    failing = ('BBRCDEV_E02939',)
    # The whole snapshot workflow is inherited; only the overlaid BAMOS output
    # changes (first pattern is the FLAIR background image).
    files = [
        'FLAIR_{label}*.nii.gz',
        'Layers_{label}*.nii.gz',
    ]