#!/usr/bin/env python3

# Copyright (C) 2018 Jasper Boom (jboom@infernum.nl)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License version 3 as 
# published by the Free Software Foundation.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Imports
import argparse
import shutil
import string
import random
import re
import os
import zipfile
import multiprocessing as mp
import pandas as pd
import subprocess as sp


# The removeWorkDirs function.
# This function removes all temporary working directories.
def removeWorkDirs(strDir):
    """Delete the directory strDir and everything stored below it.

    Any error raised by shutil.rmtree (for example when strDir does not
    exist) is passed on to the caller.
    """
    shutil.rmtree(path=strDir)


# The setOutputFiles function.
# This function creates the tabular and BLAST output files.
def setOutputFiles(strDir, strOutputFileName):
    """Write the tabular and BLAST output files from the centroid files.

    Every file in strDir/postClustering is read. Its name is expected to be
    "<UMI number>_<UMI sequence>.fasta" and its content two-line fasta
    records whose header contains "=<read count>".

    A file of exactly two lines yields one record identified by the UMI
    number. A longer file yields one record per header, identified as
    "<UMI number>.<n>" with n counting from 1. Files with fewer than two
    lines are skipped.

    Parameters:
        strDir: working directory that contains the postClustering folder.
        strOutputFileName: path prefix; "_TABULAR.tbl" and "_BLAST.fasta"
            are appended to it.

    The tabular file is always written (header only when there are no
    records). The BLAST file is appended to, and only touched when at least
    one record was found.
    """
    strClusterDir = strDir + "/postClustering"
    flBlast = strOutputFileName + "_BLAST.fasta"
    flTabular = strOutputFileName + "_TABULAR.tbl"
    lstRows = []
    lstBlast = []
    # os.listdir gives no ordering guarantee; sorting keeps the output
    # reproducible between runs.
    for strFileName in sorted(os.listdir(strClusterDir)):
        lstNameParts = strFileName.split("_")
        strUmiNumber = lstNameParts[0]
        # Drop the ".fasta" extension (6 characters) from the UMI sequence.
        strUmiString = lstNameParts[1][:-6]
        with open(strClusterDir + "/" + strFileName) as oisUmiFile:
            lstLines = oisUmiFile.readlines()
        if len(lstLines) < 2:
            continue
        # More than one two-line record means several centroids for one UMI.
        blnVersioned = len(lstLines) > 2
        intVersionCount = 1
        itLines = iter(lstLines)
        for strLine in itLines:
            if not strLine.startswith(">"):
                continue
            strRead = next(itLines, None)
            if strRead is None:
                # Header without a sequence line: nothing to report.
                break
            strHeader = strLine.split("=")[1].strip("\n")
            strRead = strRead.strip("\n").upper()
            if blnVersioned:
                strUmiId = strUmiNumber + "." + str(intVersionCount)
                intVersionCount += 1
            else:
                strUmiId = strUmiNumber
            lstRows.append([strUmiId, strUmiString, strHeader, strRead])
            lstBlast.append(">" + strUmiId + "\n" + strRead + "\n")
    if lstBlast:
        with open(flBlast, "a") as flOutput:
            flOutput.writelines(lstBlast)
    # Build the table in one go; growing a DataFrame row by row copies the
    # whole frame on every insert.
    dfOutput = pd.DataFrame(
        lstRows, columns=["UMI ID", "UMI Seq", "Read Count", "Centroid Read"]
    )
    dfOutput = dfOutput.set_index("UMI ID")
    dfOutput.to_csv(flTabular, sep="\t", encoding="utf-8")


# The setClusterSize function.
# This function controls the Vsearch clustering. Every fasta file created by
# setSortBySize is clustered using Vsearch. The expected result is a single
# centroid sequence. This is checked in the setOutputFiles function.
def setClusterSize(strDir, strIdentity):
    """Run "vsearch --cluster_size" on every sorted fasta file.

    Files in strDir/clustering whose name starts with "sorted" are used as
    input. The centroids are written to strDir/postClustering under the
    input name with its first 11 characters (the "sortedderep" prefix)
    removed. strIdentity is handed to vsearch as the --id value.

    The output and exit status of vsearch are captured but not inspected.
    """
    strSourceDir = strDir + "/clustering/"
    strTargetDir = strDir + "/postClustering/"
    for strName in os.listdir(strDir + "/clustering"):
        if not strName.startswith("sorted"):
            continue
        lstCommand = ["vsearch", "--cluster_size", strSourceDir + strName]
        lstCommand += ["--fasta_width", "0"]
        lstCommand += ["--id", strIdentity]
        lstCommand += ["--sizein"]
        lstCommand += ["--minseqlength", "1"]
        lstCommand += ["--centroids", strTargetDir + strName[11:]]
        lstCommand += ["--sizeout"]
        sp.run(lstCommand, stdout=sp.PIPE, stderr=sp.PIPE)


# The setSortBySize function.
# This function controls the Vsearch sorting. Every fasta file created by
# setDereplication is sorted based on abundance. Any reads with an abundance
# lower than strAbundance will be discarded.
def setSortBySize(strDir, strAbundance):
    """Run "vsearch --sortbysize" on every dereplicated fasta file.

    Files in strDir/clustering whose name starts with "derep" are used as
    input. The result is written to the same directory with "sorted" put in
    front of the input file name. strAbundance is handed to vsearch as the
    --minsize value.

    The output and exit status of vsearch are captured but not inspected.
    """
    strWorkDir = strDir + "/clustering/"
    for strName in os.listdir(strDir + "/clustering"):
        if not strName.startswith("derep"):
            continue
        lstCommand = ["vsearch", "--sortbysize", strWorkDir + strName]
        lstCommand += ["--output", strWorkDir + "sorted" + strName]
        lstCommand += ["--minseqlength", "1"]
        lstCommand += ["--minsize", strAbundance]
        sp.run(lstCommand, stdout=sp.PIPE, stderr=sp.PIPE)


# The setDereplication function.
# This function controls the Vsearch dereplication. Every fasta file created by
# getFastaFile is dereplicated. This step is necessary for the sorting step to
# work.
def setDereplication(strDir):
    """Run "vsearch --derep_fulllength" on every UMI fasta file.

    Files in strDir/preZip whose name ends with ".fasta" are used as input.
    The result is written to strDir/clustering with "derep" put in front of
    the input file name.

    The output and exit status of vsearch are captured but not inspected.
    """
    strSourceDir = strDir + "/preZip/"
    strTargetDir = strDir + "/clustering/"
    for strName in os.listdir(strDir + "/preZip"):
        if not strName.endswith(".fasta"):
            continue
        lstCommand = ["vsearch", "--derep_fulllength", strSourceDir + strName]
        lstCommand += ["--output", strTargetDir + "derep" + strName]
        lstCommand += ["--minseqlength", "1"]
        lstCommand += ["--sizeout"]
        sp.run(lstCommand, stdout=sp.PIPE, stderr=sp.PIPE)


# The getZipArchive function.
# This function creates a zip archive from all files in the specified
# directory.
def getZipArchive(strDir, flZip, strExtension):
    """Collect the matching files of one directory in a new zip archive.

    Parameters:
        strDir: directory whose direct entries are inspected.
        flZip: path of the archive; an existing file is overwritten.
        strExtension: only entries whose name ends with this are added.

    Every file is stored under its bare file name, without directories.
    """
    with zipfile.ZipFile(flZip, "w") as objZip:
        lstMatches = [
            strEntry
            for strEntry in os.listdir(strDir)
            if strEntry.endswith(strExtension)
        ]
        for strEntry in lstMatches:
            strSource = strDir + "/" + strEntry
            objZip.write(strSource, arcname=os.path.basename(strSource))


# The getFastaFile function.
# This function creates separate fasta files for every unique UMI. The function
# creates a unique name for every UMI file and combines that with the desired
# output path. A file is opened or created based on this combination. The
# read header and the read itself are appended to it.
def getFastaFile(strDir, dicUniqueUmi, strHeader, strRead, strCode):
    """Append one read to the fasta file that belongs to its UMI.

    The file lives in strDir/preZip and is named
    "UMI#<number>_<strCode>.fasta", where the number is looked up in
    dicUniqueUmi. strHeader and strRead are written as given, so both are
    expected to carry their own line ending.
    """
    strUmiNumber = str(dicUniqueUmi[strCode])
    strTarget = "".join(
        [strDir, "/preZip/", "UMI#", strUmiNumber, "_", strCode, ".fasta"]
    )
    with open(strTarget, "a") as flOutput:
        flOutput.writelines((strHeader, strRead))


# The useZeroPosition function.
# This function will isolate either a 5'-end, 3'-end or double UMI based on
# the starting or ending position of a read.
# It will check if both the forward and reverse primer can be found. If this
# check is passed, the 5'-end, 3'-end UMI or double UMI will be isolated by
# adding or subtracting the UMI length from the first or last position of the
# read. The function will return (if possible) the UMI nucleotides.
def useZeroPosition(strSearch, intUmiLength, strRead, strForward, strReverse):
    """Take the UMI from the very start and/or end of the read.

    Parameters:
        strSearch: "umi5", "umi3" or "umidouble".
        intUmiLength: UMI length; anything int() accepts.
        strRead: the read to search.
        strForward, strReverse: regex patterns that must both occur in the
            read, otherwise no UMI is returned.

    Returns the first intUmiLength characters for "umi5", the last
    intUmiLength characters for "umi3", a (first, last) tuple for
    "umidouble" and None in every other case.
    """
    if re.search(strForward, strRead) is None:
        return None
    if re.search(strReverse, strRead) is None:
        return None
    if strSearch not in ("umi5", "umidouble", "umi3"):
        return None
    intLength = int(intUmiLength)
    strHead = strRead[:intLength]
    strTail = strRead[-intLength:]
    if strSearch == "umi5":
        return strHead
    if strSearch == "umi3":
        return strTail
    return strHead, strTail


# The useAdapter function.
# This function searches for a regex string in the provided read. It will
# isolate either a 5'-end, 3'-end or double UMI. The isolation is based on
# this read structure:
#     ADAPTER(F)-UMI(5')-PRIMER(F)-INSERT-PRIMER(R)-UMI(3')-ADAPTER(R).
# When looking for the 5'-end UMI, the last position of ADAPTER(F) is used,
# when looking for the 3'-end UMI, the first position of ADAPTER(R) is used,
# when looking for the double UMI, both mentioned positions are used.
# These positions plus or minus the UMI length result in the UMI nucleotides.
# In the case of umi5 or umi3, a check needs to be passed. This check makes
# sure the opposite adapters are also present, otherwise no UMI is returned.
# The function will return (if possible) the UMI nucleotides.
def useAdapter(strSearch, intUmiLength, strRead, strForward, strReverse):
    """Isolate the UMI(s) located directly inside the adapters.

    The 5'-end UMI is the intUmiLength characters that follow the first
    match of strForward; the 3'-end UMI is the intUmiLength characters that
    precede the first match of strReverse.

    Returns a string for "umi5" and "umi3", a (5'-end, 3'-end) tuple for
    "umidouble" and None when strSearch is something else or when the
    opposite adapter of a "umi5"/"umi3" search is missing.

    Raises AttributeError when the adapter that positions the requested UMI
    is not found in the read; getUmi relies on this.
    """
    if strSearch == "umi3":
        # The forward adapter only has to be present.
        if re.search(strForward, strRead) is None:
            return None
        intStop = re.search(strReverse, strRead).start()
        return strRead[intStop - int(intUmiLength) : intStop]
    if strSearch not in ("umi5", "umidouble"):
        return None
    intBegin = re.search(strForward, strRead).end()
    strUmiFront = strRead[intBegin : intBegin + int(intUmiLength)]
    if strSearch == "umi5":
        # The reverse adapter only has to be present.
        if re.search(strReverse, strRead) is None:
            return None
        return strUmiFront
    intStop = re.search(strReverse, strRead).start()
    strUmiBack = strRead[intStop - int(intUmiLength) : intStop]
    return strUmiFront, strUmiBack


# The usePrimer function.
# This function searches for a regex string in the provided read. It will
# isolate either a 5'-end, 3'-end or double UMI. The isolation is based on
# this read structure:
#     UMI(5')-PRIMER(F)-INSERT-PRIMER(R)-UMI(3').
# When looking for the 5'-end UMI, the first position of PRIMER(F) is used,
# when looking for the 3'-end UMI, the last position of PRIMER(R) is used,
# when looking for the double UMI, both mentioned positions are used.
# These positions plus or minus the UMI length result in the UMI nucleotides.
# In the case of umi5 or umi3, a check needs to be passed. This check makes
# sure the opposite primer is also present, otherwise no UMI is returned.
# The function will return (if possible) the UMI nucleotides.
def usePrimer(strSearch, intUmiLength, strRead, strForward, strReverse):
    """Isolate the UMI(s) located directly outside the primers.

    The 5'-end UMI is the intUmiLength characters that precede the first
    match of strForward; the 3'-end UMI is the intUmiLength characters that
    follow the first match of strReverse.

    Returns a string for "umi5" and "umi3", a (5'-end, 3'-end) tuple for
    "umidouble" and None when strSearch is something else or when the
    opposite primer of a "umi5"/"umi3" search is missing.

    Raises AttributeError when the primer that positions the requested UMI
    is not found in the read; getUmi relies on this.
    """
    if strSearch == "umi3":
        # The forward primer only has to be present.
        if re.search(strForward, strRead) is None:
            return None
        intBegin = re.search(strReverse, strRead).end()
        return strRead[intBegin : intBegin + int(intUmiLength)]
    if strSearch not in ("umi5", "umidouble"):
        return None
    intStop = re.search(strForward, strRead).start()
    strUmiFront = strRead[intStop - int(intUmiLength) : intStop]
    if strSearch == "umi5":
        # The reverse primer only has to be present.
        if re.search(strReverse, strRead) is None:
            return None
        return strUmiFront
    intBegin = re.search(strReverse, strRead).end()
    strUmiBack = strRead[intBegin : intBegin + int(intUmiLength)]
    return strUmiFront, strUmiBack


# The getReverseComplement function.
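# This function creates a complementary string using a nucleotide string as
# input. The function loops through a list version of the nucleotide string
# and checks/changes every character. The function then returns the new
# string. Note that the string is only complemented, not reversed; getUmi
# reverses the sequence before calling this function.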
def getReverseComplement(strLine):
    """Return strLine with every IUPAC nucleotide code complemented.

    The order of the characters is kept; reversing is left to the caller.
    Raises KeyError for a character that is not an upper case IUPAC code.
    """
    dicComplementCodes = {
        "A": "T", "T": "A", "G": "C", "C": "G",
        "M": "K", "R": "Y", "W": "W", "S": "S", "Y": "R", "K": "M",
        "V": "B", "H": "D", "D": "H", "B": "V",
        "N": "N",
    }
    return "".join(dicComplementCodes[strCode] for strCode in strLine)


# The getRegex function.
# This function creates a regex string using a nucleotide string as input. This
# regex string is based on IUPAC ambiguity codes. The function loops through
# a list version of the nucleotide string and checks per character if it is an
# ambiguous character. If an ambiguous character is found, it is replaced by a
# regex version. The function then returns the new string.
def getRegex(strLine):
    """Translate a nucleotide string into a regex pattern.

    A, T, G and C are copied unchanged. Every IUPAC ambiguity code is
    replaced by a character class listing the nucleotides it stands for.
    Raises KeyError for any other character.
    """
    dicAmbiguityCodes = {
        "M": "[AC]", "R": "[AG]", "W": "[AT]",
        "S": "[CG]", "Y": "[CT]", "K": "[GT]",
        "V": "[ACG]", "H": "[ACT]", "D": "[AGT]", "B": "[CGT]",
        "N": "[GATC]",
    }
    tplPlain = ("A", "T", "G", "C")
    return "".join(
        strCode if strCode in tplPlain else dicAmbiguityCodes[strCode]
        for strCode in strLine
    )


# The getUmi function.
# This function controls the UMI searching approach. It first uses the
# functions getRegex and getReverseComplement to create regex strings of both
# the forward and reverse primers/adapters. The regex strings are then directed
# to the associated approach functions [primer/adapter/zero].
def getUmi(
    strSearch, strApproach, intUmiLength, strForward, strReverse, strRead
):
    """Extract the UMI(s) of one read with the requested approach.

    strForward is turned into a regex as given. strReverse is first
    reversed and complemented, so that it is searched in the orientation in
    which it occurs in the read. The patterns are then handed to usePrimer,
    useAdapter or useZeroPosition, selected by strApproach ("primer",
    "adapter" or "zero").

    Returns whatever the selected function returns, or None when
    strApproach is unknown or the function raised AttributeError (a
    required pattern was not found in the read).
    """
    strRead = strRead.strip("\n")
    strRegexForward = getRegex(strForward)
    strRegexComplementReverse = getRegex(
        getReverseComplement(strReverse[::-1])
    )
    if strApproach == "primer":
        fncApproach = usePrimer
    elif strApproach == "adapter":
        fncApproach = useAdapter
    elif strApproach == "zero":
        fncApproach = useZeroPosition
    else:
        return None
    try:
        return fncApproach(
            strSearch,
            intUmiLength,
            strRead,
            strRegexForward,
            strRegexComplementReverse,
        )
    except AttributeError:
        # re.search returned None for a pattern the approach depends on.
        return None


# The processInputFile function.
# This function opens the input file and loops through it. It stores the read
# header and read nucleotides. For every read the getUmi function is called,
# this outputs one or two UMI codes. In the case of a double UMI search
# [umidouble], the two UMIs are combined. The length of the UMI is checked
# before continuing. The getFastaFile function is called for every read that
# contains a UMI.
def processInputFile(
    flInput,
    strSearch,
    strApproach,
    intUmiLength,
    strForward,
    strReverse,
    strDir,
    strOperand,
):
    """Sort the reads of one input file into per-UMI fasta files.

    A line counts as a read header when it starts with strOperand followed
    by a letter or digit; the line after it is taken as the read (so
    sequences are assumed to occupy a single line). The read is upper
    cased, getUmi is asked for its UMI and, when the UMI has the expected
    length, header and read are appended to the UMI's file via
    getFastaFile. UMIs are numbered from 1 in order of first appearance.

    Parameters:
        flInput: path of the fasta/fastq file.
        strSearch: "umi5", "umi3" or "umidouble"; any other value yields
            no output.
        strApproach: passed on to getUmi.
        intUmiLength: length of a single UMI; anything int() accepts.
        strForward, strReverse: search nucleotides, upper cased here.
        strDir: working directory that contains the preZip folder.
        strOperand: first character of a header line (">" or "@").

    Raises ValueError when intUmiLength is not a valid integer.
    """
    intLength = int(intUmiLength)
    # A double UMI is the 5'-end and 3'-end UMI glued together. The length
    # has to be doubled as an integer: it arrives from the command line as
    # a string, and "8" * 2 would be "88" instead of 16.
    if strSearch == "umidouble":
        intExpectedLength = intLength * 2
    else:
        intExpectedLength = intLength
    strForward = strForward.upper()
    strReverse = strReverse.upper()
    dicUniqueUmi = {}
    # FASTQ bookkeeping: after a record's sequence line come a "+" line and
    # a quality line. Quality strings may start with "@", so they must be
    # skipped instead of being tested as headers.
    blnSeparatorNext = False
    blnQualityNext = False
    with open(flInput) as oisInput:
        for strLine in oisInput:
            if blnQualityNext:
                blnQualityNext = False
                continue
            if blnSeparatorNext:
                blnSeparatorNext = False
                if strLine.startswith("+"):
                    blnQualityNext = True
                    continue
            # Slices instead of indexes: a one-character last line must not
            # raise IndexError.
            if (
                strLine[:1] != strOperand
                or re.match("[A-Za-z0-9]", strLine[1:2]) is None
            ):
                continue
            strHeader = strLine
            strRead = next(oisInput, None)
            if strRead is None:
                # Header on the last line of the file: no read to process.
                break
            strRead = strRead.upper()
            blnSeparatorNext = strOperand == "@"
            strUmi = getUmi(
                strSearch,
                strApproach,
                intLength,
                strForward,
                strReverse,
                strRead,
            )
            if strUmi is None:
                continue
            if strSearch == "umidouble":
                strCode = strUmi[0] + strUmi[1]
            elif strSearch == "umi5" or strSearch == "umi3":
                strCode = strUmi
            else:
                continue
            # A shorter result means the UMI ran off the edge of the read.
            if len(strCode) != intExpectedLength:
                continue
            if strCode not in dicUniqueUmi:
                dicUniqueUmi[strCode] = len(dicUniqueUmi) + 1
            getFastaFile(strDir, dicUniqueUmi, strHeader, strRead, strCode)


# The runCaltha function.
# This function controls and calls the main functionality of Caltha.
def runCaltha(
    flInput,
    strMainDir,
    strSearch,
    strApproach,
    intUmiLength,
    strForward,
    strReverse,
    strOperand,
    intAbundance,
    fltIdentity,
):
    """Run the complete Caltha pipeline for one input file.

    The output files are placed in strMainDir and named after the input
    file: the part of its file name before the first ".". A private
    temporary directory is created in strMainDir for the intermediate
    files.

    Steps, in order: split the reads per UMI, zip the per-UMI fasta files
    ("_PREZIP.zip"), dereplicate, sort by abundance, cluster, and write the
    tabular and BLAST files.

    Returns the string "DONE".
    """
    strBaseName = flInput.rsplit("/", 1)[-1]
    strInputFileName = strBaseName.partition(".")[0]
    strOutputFileName = "/".join((strMainDir, strInputFileName))
    strTempDir = setWorkDirs(strMainDir, True)
    processInputFile(
        flInput=flInput,
        strSearch=strSearch,
        strApproach=strApproach,
        intUmiLength=intUmiLength,
        strForward=strForward,
        strReverse=strReverse,
        strDir=strTempDir,
        strOperand=strOperand,
    )
    strPreZipDir = strTempDir + "/preZip"
    strPreZipFile = strOutputFileName + "_PREZIP.zip"
    getZipArchive(strPreZipDir, strPreZipFile, ".fasta")
    setDereplication(strTempDir)
    setSortBySize(strTempDir, intAbundance)
    setClusterSize(strTempDir, fltIdentity)
    setOutputFiles(strTempDir, strOutputFileName)
    return "DONE"


# The createInputList function.
# This function creates a list of all input file names. If the input is a zip
# file, it will create a temporary storage directory where the contents of the
# zip file is extracted to.
def createInputList(flInput, strDir, strFormat):
    """Return the list of input file paths for the given input.

    For "fasta" and "fastq" the list holds flInput itself. For "zipfasta"
    and "zipfastq" the archive flInput is extracted into the new directory
    strDir/inputFiles and the list holds the path of every extracted file.
    Directory entries of the archive are extracted but not listed, because
    they cannot be opened as read files.

    Raises ValueError for any other strFormat, and whatever os.mkdir or
    zipfile raise (for example when strDir/inputFiles already exists).
    """
    if strFormat == "fasta" or strFormat == "fastq":
        return [flInput]
    if strFormat == "zipfasta" or strFormat == "zipfastq":
        strCreate = strDir + "/" + "inputFiles"
        os.mkdir(strCreate)
        with zipfile.ZipFile(flInput, "r") as objZip:
            objZip.extractall(strCreate)
            # Names of directory entries end with "/".
            return [
                strCreate + "/" + strFile
                for strFile in objZip.namelist()
                if not strFile.endswith("/")
            ]
    raise ValueError("Unknown input format: " + repr(strFormat))


# The setWorkDirs function.
# This function creates the main temporary directory and all subprocess
# directories. It checks if the directories already exist and creates them
# if they don't.
def setWorkDirs(strDir, blnExtra):
    """Create a randomly named working directory inside strDir.

    The name consists of 10 random lower case ASCII letters. When blnExtra
    is True the subdirectories preZip, clustering and postClustering are
    created as well. Directories that already exist are left untouched.

    Returns the path of the working directory.
    """
    strRandom = "".join(
        [random.choice(string.ascii_lowercase) for _ in range(10)]
    )
    strRoot = strDir + "/" + strRandom
    lstWorkDirs = [strRoot]
    if blnExtra is True:
        lstWorkDirs.extend(
            strRoot + strSub
            for strSub in ("/preZip", "/clustering", "/postClustering")
        )
    for strDirectory in lstWorkDirs:
        if os.path.exists(strDirectory):
            continue
        os.mkdir(strDirectory)
    return strRoot


# The setFormat function.
# This function specifies the first character of the read headers. This
# character is based on the input file format. The function then returns this
# character.
def setFormat(strFormat):
    """Return the character that starts a read header in the given format.

    ">" for "fasta" and "zipfasta", "@" for "fastq" and "zipfastq", and
    None for anything else.
    """
    if strFormat in ("fasta", "zipfasta"):
        return ">"
    if strFormat in ("fastq", "zipfastq"):
        return "@"
    return None


# The parseArgvs function.
# This function defines the command line arguments and parses them.
def parseArgvs():
    """Define and parse the command line arguments.

    The options -i, -t, -z, -b, -f, -u, -w and -r are read unconditionally
    by main and are therefore required. The values of -f, -s and -a are
    restricted to the values the rest of the program recognises. The UMI
    length (-u) and the number of processes (-p) are parsed as integers.
    Identity (-y) and abundance (-c) stay strings because they are placed
    in the vsearch command lines as given.

    Returns the argparse namespace. Invalid or missing arguments make
    argparse print a usage message and exit.
    """
    strDescription = "A python package to process UMI tagged mixed amplicon\
                      metabarcoding data."
    strEpilog = "This python package requires one extra dependency which can\
                 be easily installed with conda (conda install -c bioconda\
                 vsearch)."
    parser = argparse.ArgumentParser(
        description=strDescription,
        epilog=strEpilog,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-v", "-version", action="version", version="%(prog)s [0.4]"
    )
    parser.add_argument(
        "-i",
        "-input",
        action="store",
        dest="fisInput",
        default=argparse.SUPPRESS,
        required=True,
        help="The location of the input fasta/fastq file(s).",
    )
    parser.add_argument(
        "-t",
        "-tabular",
        action="store",
        dest="fosTabular",
        default=argparse.SUPPRESS,
        required=True,
        help="The location of the output tabular zip file.",
    )
    parser.add_argument(
        "-z",
        "-zip",
        action="store",
        dest="fosPreZip",
        default=argparse.SUPPRESS,
        required=True,
        help="The location of the pre validation zip file.",
    )
    parser.add_argument(
        "-b",
        "-blast",
        action="store",
        dest="fosBlast",
        default=argparse.SUPPRESS,
        required=True,
        help="The location of the output blast zip file.",
    )
    parser.add_argument(
        "-f",
        "-format",
        action="store",
        dest="disFormat",
        default=argparse.SUPPRESS,
        required=True,
        choices=["fasta", "fastq", "zipfasta", "zipfastq"],
        help="The format of the input file [fasta/fastq/zipfasta/zipfastq].",
    )
    parser.add_argument(
        "-s",
        "-search",
        action="store",
        dest="disSearch",
        default="umi5",
        choices=["umi5", "umi3", "umidouble"],
        help="Search UMIs at the 5'-end\
              [umi5], 3'-end [umi3] or at the 5'-end and 3'-end [umidouble].",
    )
    parser.add_argument(
        "-a",
        "-approach",
        action="store",
        dest="disApproach",
        default="primer",
        choices=["primer", "adapter", "zero"],
        help="The UMI search approach [primer/adapter/zero].",
    )
    parser.add_argument(
        "-u",
        "-length",
        action="store",
        dest="disUmiLength",
        default=argparse.SUPPRESS,
        required=True,
        # Parsed as an integer so that arithmetic on the length (such as
        # doubling it for umidouble) is numeric rather than string based.
        type=int,
        help="The length of the UMI sequence.",
    )
    parser.add_argument(
        "-y",
        "-identity",
        action="store",
        dest="disIdentity",
        default="0.97",
        help="The identity percentage with which to perform the validation.",
    )
    parser.add_argument(
        "-c",
        "-abundance",
        action="store",
        dest="disAbundance",
        default="1",
        help="The minimum abundance of a read\
              in order to be included during validation.",
    )
    parser.add_argument(
        "-w",
        "-forward",
        action="store",
        dest="disForward",
        default=argparse.SUPPRESS,
        required=True,
        help="The 5'-end search nucleotides.",
    )
    parser.add_argument(
        "-r",
        "-reverse",
        action="store",
        dest="disReverse",
        default=argparse.SUPPRESS,
        required=True,
        help="The 3'-end search nucleotides.",
    )
    parser.add_argument(
        "-d",
        "-directory",
        action="store",
        dest="fisDirectory",
        default=".",
        help="The location where the temporary working directory will\
              be created.",
    )
    parser.add_argument(
        "-p",
        "-processes",
        action="store",
        dest="disProcesses",
        default=mp.cpu_count(),
        type=int,
        help="The number of threads/cores/processes to simultaneously run\
              Caltha with.",
    )
    argvs = parser.parse_args()
    return argvs


# The main function.
def main():
    """Run Caltha for every input file and bundle the results.

    A temporary main directory is created, runCaltha is started for every
    input file in a process pool, and the resulting .tbl, .zip and .fasta
    files of the main directory are collected in the three output archives.
    The main directory is removed afterwards, also when a step fails.

    A failure inside a worker is reported on stderr and does not stop the
    other input files. The status of every input file ("DONE" or "FAILED")
    is printed at the end.
    """
    # Local import: sys is only needed here, to report worker failures.
    import sys

    argvs = parseArgvs()
    strOperand = setFormat(argvs.disFormat)
    strMainDir = setWorkDirs(argvs.fisDirectory, False)
    try:
        lstInput = createInputList(argvs.fisInput, strMainDir, argvs.disFormat)
        mpPool = mp.Pool(int(argvs.disProcesses))
        mpTemporary = [
            mpPool.apply_async(
                runCaltha,
                args=(
                    flInput,
                    strMainDir,
                    argvs.disSearch,
                    argvs.disApproach,
                    argvs.disUmiLength,
                    argvs.disForward,
                    argvs.disReverse,
                    strOperand,
                    argvs.disAbundance,
                    argvs.disIdentity,
                ),
            )
            for flInput in lstInput
        ]
        mpPool.close()
        mpPool.join()
        # An exception raised in a worker only surfaces when its result is
        # fetched; without this step failures would go unnoticed.
        lstStatus = []
        for flInput, mpResult in zip(lstInput, mpTemporary):
            try:
                lstStatus.append(mpResult.get())
            except Exception as excError:
                print(
                    "Caltha failed for " + flInput + ": " + repr(excError),
                    file=sys.stderr,
                )
                lstStatus.append("FAILED")
        getZipArchive(strMainDir, argvs.fosTabular, ".tbl")
        getZipArchive(strMainDir, argvs.fosPreZip, ".zip")
        getZipArchive(strMainDir, argvs.fosBlast, ".fasta")
    finally:
        removeWorkDirs(strMainDir)
    print(lstStatus)


# Run the command line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
