from ..test import ExperimentTest, Results
from . import freesurfer


class HasCorrectItems(ExperimentTest):
    """Passes if a `DONSURF` resource is found and this resource
    has the expected items according to the pipeline
    [specifications](https://gitlab.com/bbrc/xnat/xnat-pipelines/-/tree/master/donsurf#outputs)."""

    # Reference experiment IDs for this test; the trailing commas make these
    # single-element tuples.
    passing = 'BBRCDEV_E02824',
    failing = 'BBRCDEV_E02830',

    def run(self, experiment_id):
        """Look up every expected DONSURF output and collect the absent ones.

        Each template below may contain an ``{id}`` placeholder, which is
        replaced by `experiment_id` before the lookup. Every item is queried
        individually on the experiment's `DONSURF` resource.

        Returns a `Results` that passes only when nothing is missing; its
        data is the list of missing item paths, in declaration order.
        """
        templates = [
            '{id}_QC_reg.png',
            'log_file.log',
            'reg/register.dat',
            'reg/register.dat.mincost',
            'reg/register.dat.param',
            'reg/register.dat.sum',
            'reg/register.log',
            'reg/register.lta',
            'GM-MD-koo/dti_MD_koo.nii.gz',
            'GM-MD-koo/pvf.nii.gz',
            'GM-MD-koo/{id}.lh.fsaverage.fwhm0.GM-MD-koo.mgh',
            'GM-MD-koo/{id}.lh.fsaverage.fwhm15.GM-MD-koo.mgh',
            'GM-MD-koo/{id}.lh.nativesurf.GM-MD-koo.mgh',
            'GM-MD-koo/{id}.rh.fsaverage.fwhm0.GM-MD-koo.mgh',
            'GM-MD-koo/{id}.rh.fsaverage.fwhm15.GM-MD-koo.mgh',
            'GM-MD-koo/{id}.rh.nativesurf.GM-MD-koo.mgh',
            'GM-MD-koo/labels/lh.{id}_fsaverage.label',
            'GM-MD-koo/labels/rh.{id}_fsaverage.label',
            'SWM-MD/{id}.lh.fsaverage.fwhm0.SWM-MD.mgh',
            'SWM-MD/{id}.lh.fsaverage.fwhm15.SWM-MD.mgh',
            'SWM-MD/{id}.lh.nativesurf.SWM-MD.mgh',
            'SWM-MD/{id}.rh.fsaverage.fwhm0.SWM-MD.mgh',
            'SWM-MD/{id}.rh.fsaverage.fwhm15.SWM-MD.mgh',
            'SWM-MD/{id}.rh.nativesurf.SWM-MD.mgh',
            'SWM-MD/labels/lh.{id}_fsaverage.label',
            'SWM-MD/labels/rh.{id}_fsaverage.label',
            'GM-MD/{id}.lh.fsaverage.fwhm0.GM-MD.mgh',
            'GM-MD/{id}.lh.fsaverage.fwhm15.GM-MD.mgh',
            'GM-MD/{id}.lh.nativesurf.GM-MD.mgh',
            'GM-MD/{id}.lh.nativesurf10000.GM-MD',
            'GM-MD/{id}.rh.fsaverage.fwhm0.GM-MD.mgh',
            'GM-MD/{id}.rh.fsaverage.fwhm15.GM-MD.mgh',
            'GM-MD/{id}.rh.nativesurf.GM-MD.mgh',
            'GM-MD/{id}.rh.nativesurf10000.GM-MD',
            'GM-MD/labels/lh.{id}_fsaverage.label',
            'GM-MD/labels/rh.{id}_fsaverage.label',
            'stats/lh.aparc.stats',
            'stats/rh.aparc.stats',
        ]
        wanted = [template.format(id=experiment_id) for template in templates]

        experiment = self.xnat_instance.select.experiment(experiment_id)
        donsurf = experiment.resource('DONSURF')

        # An item counts as missing when its file lookup comes back empty.
        missing = [path for path in wanted if not donsurf.files(path).get()]

        return Results(not missing, data=missing)

    def report(self):
        """Return report lines: empty on success, else the missing items."""
        if self.results.has_passed:
            return []
        # Single quotes from the list repr are turned into backticks for
        # markdown rendering.
        listing = 'Missing items: {}.'.format(self.results.data)
        return [listing.replace('\'', '`')]


class HasCorrectFreeSurferVersion(ExperimentTest):
    """This test checks the version of FreeSurfer used. Passes if DONSURF
    outputs were created using the expected version (`{version}`)."""

    # Reference experiment IDs for this test; the trailing commas make these
    # single-element tuples.
    passing = 'BBRCDEV_E02824',
    failing = 'BBRCDEV_E02830',
    resource_name = 'DONSURF'
    expected_version = 'freesurfer-linux-centos7_x86_64-7.1.1-20200723-8b40551'
    # Fill the `{version}` placeholder of the class docstring above.
    __doc__ = __doc__.format(version=expected_version)

    def run(self, experiment_id):
        """Compare the FreeSurfer version line of the DONSURF log.

        The log is read from ``LOGS/<experiment label>.log`` in the
        `DONSURF` resource. The first line starting with ``freesurfer-``
        (ignoring surrounding whitespace) is taken as the version string.

        Returns a failed `Results` with a one-item message list if the log
        is absent, if it contains no version line, or if the version differs
        from `expected_version`. Otherwise returns a passed `Results` with
        empty data.
        """
        e = self.xnat_instance.select.experiment(experiment_id)
        r = e.resource(self.resource_name)
        # NOTE(review): HasCorrectItems expects a top-level 'log_file.log'
        # in this resource; confirm that LOGS/<label>.log is the right log.
        log = r.file('LOGS/{}.log'.format(e.label()))
        if not log.exists():
            msg = '{} log file not found.'.format(self.resource_name)
            return Results(False, data=[msg])

        log_data = self.xnat_instance.get(log._uri).text
        # Keep the stripped line so that stray whitespace cannot cause a
        # false mismatch against `expected_version`.
        versions = [line.strip() for line in log_data.splitlines()
                    if line.strip().startswith('freesurfer-')]

        # Without this guard, indexing versions[0] raised IndexError.
        if not versions:
            msg = 'FreeSurfer version not found in {} log file.'\
                .format(self.resource_name)
            return Results(False, data=[msg])

        if versions[0] != self.expected_version:
            return Results(False, data=['Incorrect FreeSurfer version: '
                                        '{}'.format(versions[0])])

        return Results(True, data=[])


class DWIRegistrationSnapshot(ExperimentTest):
    """This test collects a snapshot of the accuracy of the registration step
    of the DWI image to the T1w image space, generated by the `DONSURF` pipeline.
    Snapshot consists of a background DWI image (grayscale) overlaid with contoured
    segmentations of both hemispheres (left=red, right=yellow).
    Test passes if the snapshot is successfully collected, fails otherwise. Does
    not tell anything about the quality of the registration procedure."""

    # Reference experiment IDs for this test; the trailing commas make these
    # single-element tuples.
    passing = 'BBRCDEV_E02824',
    failing = 'BBRCDEV_E02830',

    def run(self, experiment_id):
        """Download the `DONSURF` registration QC snapshot to a temp file.

        If the environment variable ``SKIP_SNAPSHOTS_TESTS`` equals
        ``'True'``, nothing is downloaded. The result then passes only when
        `experiment_id` is the first entry of `passing`.

        Returns a `Results` whose data is ``[path_to_png]`` on success, or a
        one-item list holding an explanatory message otherwise. Exceptions
        raised while downloading are propagated after the temp file is
        removed.
        """
        import os
        import tempfile

        resource_name = 'DONSURF'

        if os.getenv('SKIP_SNAPSHOTS_TESTS') == 'True':
            return Results(experiment_id == self.passing[0],
                           data=['Skipping it. (SKIP_SNAPSHOTS_TESTS)'])

        e = self.xnat_instance.select.experiment(experiment_id)
        snap = e.resource(resource_name).file('{}_QC_reg.png'.format(e.id()))

        if not snap.exists():
            msg = '{} QC snapshot file not found.'.format(resource_name)
            return Results(False, data=[msg])

        # mkstemp creates the file and returns an open descriptor; close it
        # so that only the path is handed to the download call.
        fd, fp = tempfile.mkstemp(suffix='_QC_reg.png')
        os.close(fd)
        try:
            snap.get(fp)
        except Exception:
            # Best-effort cleanup so that a failed download does not leave an
            # orphaned temp file; the original error is re-raised regardless.
            try:
                os.remove(fp)
            except OSError:
                pass
            raise

        # The file is intentionally kept: report() embeds its path.
        return Results(True, [fp])

    def report(self):
        """Return markdown image links on success, the result data otherwise."""
        report = []
        if self.results.has_passed:
            for path in self.results.data:
                report.append('![snapshot]({})'.format(path))
        else:
            report = self.results.data

        return report
