# **************************************************************************
# *
# * Authors:     Estrella Fernandez Gimenez (me.fernandez@cnb.csic.es)
# *
# *  BCU, Centro Nacional de Biotecnologia, CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************

from os.path import abspath, basename

from pyworkflow.protocol.params import PathParam
from pyworkflow.utils.path import createAbsLink, copyFile
from pwem.emlib.image import ImageHandler
from pwem.objects import Transform

from tomo.protocols.protocol_base import ProtTomoImportFiles
from tomo.objects import SubTomogram
from tomo.utils import _getUniqueFileName

from ..convert import readDynTable


class DynamoImportSubtomos(ProtTomoImportFiles):
    """ This protocol imports subtomograms with metadata generated by a Dynamo table.
    A Dynamo catalogue can also be imported in order to relate subtomograms with their original tomograms. """

    _label = 'import subtomos from Dynamo'

    # -------------------------- DEFINE param functions ----------------------
    def _defineParams(self, form):
        ProtTomoImportFiles._defineParams(self, form)
        form.addParam('tablePath', PathParam,
                      label="Dynamo table file:",
                      help='Select dynamo table (.tbl) to link dynamo metadata to the subtomograms that will be '
                           'imported to Scipion. ')
        form.addParam('ctgPath', PathParam,
                      label="Dynamo catalogue file:", allowsNull=True,
                      help='Select dynamo catalogue (.ctg) to link original tomograms to the subtomograms that will be '
                           'imported to Scipion. ')

    # --------------------------- INSERT steps functions ----------------------
    def _insertAllSteps(self):
        self._insertFunctionStep('importSubTomogramsStep', self.getPattern(), self.samplingRate.get())
        self._insertFunctionStep('createOutputStep')

    # --------------------------- STEPS functions -----------------------------
    def importSubTomogramsStep(self, pattern, samplingRate):
        """ Link every input file into the extra dir and register it as one or more subtomograms.

        :param pattern: file pattern of the input files; only logged here, the files
            themselves are obtained from ``self.iterFiles()``.
        :param samplingRate: sampling rate (A/px) set on every subtomogram and on the output set.
        """
        self.info("Using pattern: '%s'" % pattern)
        # A single SubTomogram instance is reused for every appended item.
        subtomo = SubTomogram()
        subtomo.setSamplingRate(samplingRate)

        # 1-based tomogram id -> tomogram name; stays empty when no catalogue was provided.
        self.tomoDict = self._readCatalogue() if self._hasCatalogue() else {}

        imgh = ImageHandler()
        self.subtomoSet = self._createSetOfSubTomograms()
        self.subtomoSet.setSamplingRate(samplingRate)

        dynTable = self._getExtraPath('dynamo_table.tbl')
        copyFile(self.tablePath.get(), dynTable)
        # The open handle is kept on self because readDynTable(self, ...) is called once per
        # subtomogram; NOTE(review): presumably it reads the next table row each time - confirm
        # in ..convert. Close it even if the import fails halfway.
        self.fhTable = open(dynTable, 'r')
        try:
            for fileName, _ in self.iterFiles():
                x, y, z, n = imgh.getDimensions(fileName)
                isMrc = fileName.endswith('.mrc') or fileName.endswith('.map')
                if isMrc and z == 1 and n != 1:
                    # MRC/MAP files reported as n images of depth 1 are treated as one
                    # volume whose depth is n.
                    zDim, n = n, 1
                else:
                    zDim = z

                # Origin shifts: minus half the box size on each axis, scaled by the sampling rate.
                origin = Transform()
                origin.setShifts(x / -2. * samplingRate, y / -2. * samplingRate, zDim / -2. * samplingRate)
                subtomo.setOrigin(origin)

                # Link the input into this protocol's extra dir (absolute path) rather than
                # into the current working directory.
                newFileName = abspath(self._getVolumeFileName(
                    _getUniqueFileName(self.getPattern(), fileName)))
                createAbsLink(fileName, newFileName)

                if n == 1:
                    self._addSubtomogram(subtomo, newFileName)
                else:
                    for index in range(1, n + 1):
                        self._addSubtomogram(subtomo, newFileName, index=index)
        finally:
            self.fhTable.close()

    def _hasCatalogue(self):
        """ Return True when the user provided a Dynamo catalogue file. """
        return bool(self.ctgPath.get())

    def _readCatalogue(self):
        """ Copy the catalogue into the extra dir and map tomogram ids to tomogram names.

        The first line of the file is skipped (presumably a header - confirm against the
        Dynamo catalogue/volume-list format). Each following line holds one tomogram name,
        which gets the 1-based id of its position among those lines.

        :return: dict {tomoId (int): tomoName (str)}
        """
        ctlg = self._getExtraPath('dynamo_catalogue.vll')
        copyFile(self.ctgPath.get(), ctlg)
        with open(ctlg, 'r') as fhCtlg:
            next(fhCtlg, None)  # skip the first line; tolerate an empty file
            return {tomoId: line.rstrip() for tomoId, line in enumerate(fhCtlg, start=1)}

    def _addSubtomogram(self, subtomo, newFileName, index=None):
        """ Fill *subtomo* for one file/stack entry and append it to the output set.

        :param subtomo: SubTomogram instance reused for every entry; its object id is
            cleared before each append.
        :param newFileName: path of the linked image file.
        :param index: 1-based position inside a stack, or None for a single-volume file.
        """
        subtomo.cleanObjId()
        if index is None:
            subtomo.setFileName(newFileName)
        else:
            subtomo.setLocation(index, newFileName)
        # NOTE(review): readDynTable presumably fills the Dynamo metadata (including the
        # volume id used below) from self.fhTable - confirm in ..convert.
        readDynTable(self, subtomo)
        if self._hasCatalogue():
            scipionTomoName = self.tomoDict.get(subtomo.getVolId())
            subtomo.setVolName(scipionTomoName)
            subtomo.getCoordinate3D().setVolName(scipionTomoName)
        self.subtomoSet.append(subtomo)

    def createOutputStep(self):
        """ Register the subtomogram set built by importSubTomogramsStep as the protocol output. """
        self._defineOutputs(outputSubTomograms=self.subtomoSet)

    # --------------------------- INFO functions ------------------------------
    def _hasOutput(self):
        return self.hasAttribute('outputSubTomograms')

    def _getSubTomMessage(self):
        return "SubTomograms %s" % self.getObjectTag('outputSubTomograms')

    def _summary(self):
        summary = []
        if self._hasOutput():
            summary.append("%s imported from:\n%s"
                           % (self._getSubTomMessage(), self.getPattern()))
            if self.samplingRate.get():
                summary.append(u"Sampling rate: *%0.2f* (Å/px)" %
                               self.samplingRate.get())
        return summary

    def _methods(self):
        methods = []
        if self._hasOutput():
            methods.append(" %s imported with a sampling rate *%0.2f*" %
                           (self._getSubTomMessage(), self.samplingRate.get()))
        return methods

    def _getVolumeFileName(self, fileName, extension=None):
        """ Build the extra-dir path used for an imported file.

        :param fileName: source name; only its basename is used.
        :param extension: when given, replaces everything after the first '.' of the basename;
            otherwise the basename is kept up to the first ':' (dropping e.g. a ':mrc' suffix).
        :return: path inside the protocol extra directory, prefixed with 'import_'.
        """
        if extension is not None:
            baseFileName = "import_" + str(basename(fileName)).split(".")[0] + ".%s" % extension
        else:
            baseFileName = "import_" + str(basename(fileName)).split(":")[0]
        return self._getExtraPath(baseFileName)
