#!python

"""Prep coding sequence alignments for ``phydms``.

Written by Jesse Bloom."""


import os
import re
import time
import subprocess
import matplotlib
import matplotlib.pyplot as plt
import Bio.SeqIO
import Bio.Data.CodonTable
import phydmslib.parsearguments
import phydmslib.file_io
# Select matplotlib's 'pdf' backend and the 'ggplot' style sheet for the
# figure that ``main`` draws with ``plt``.
matplotlib.use('pdf')
plt.style.use('ggplot')

# Matches every character that is replaced in sequence headers: whitespace,
# colon, semicolon, parentheses, square brackets, comma and both quote marks.
# A raw string is used so the backslashes reach ``re`` unchanged rather than
# being treated as (invalid) Python string escapes.
_cleanheader = re.compile(r'[\s\:\;\(\)\[\]\,\'\"]')


def CleanHeader(head):
    """Returns copy of header with problem chars replaced by underscore.

    *head* is the header string to clean.

    Each whitespace character, colon, semicolon, parenthesis, square
    bracket, comma, single quote and double quote is replaced by one
    underscore; all other characters are returned unchanged.
    """
    return _cleanheader.sub('_', head)


def MAFFT_CDSandProtAlignments(headseqprots, mafftcmd, tempfileprefix):
    """Builds a codon alignment by aligning the proteins with MAFFT.

    *headseqprots* is a list of *(header, ntseq, protseq)* strings. Each
    *ntseq* must be exactly three times the length of its *protseq*, and
    neither may already contain gap (``-``) characters.

    *mafftcmd* is the command used to run MAFFT.

    *tempfileprefix* is the prefix of the temporary files that are
    created. They are removed only if the function completes normally.

    Returns *alignedheadseqprots*: a list of
    *(header, alignedntseq, alignedprotseq)* tuples in the same order as
    *headseqprots*, where the nucleotide sequence carries a ``---`` codon
    wherever MAFFT placed a gap in the protein. Returned headers are
    truncated to their first 254 characters.
    """
    # Names are truncated before being handed to MAFFT.
    # NOTE(review): the slice keeps 254 characters while the assertion
    # message says 253 -- confirm which limit MAFFT actually applies.
    headseqprots = [(name[:254], cds, aas)
                    for (name, cds, aas) in headseqprots]
    names = [name for (name, cds, aas) in headseqprots]
    assert len(names) == len(set(names)),\
           "The first 253 characters of each sequence name must be unique."

    # look-ups from (truncated) name to the unaligned sequences
    seq_d = {name: cds for (name, cds, aas) in headseqprots}
    prot_d = {name: aas for (name, cds, aas) in headseqprots}

    # temporary files get a leading underscore on the base name when the
    # prefix includes a directory component
    tempdir, tempbase = os.path.split(tempfileprefix)
    tempfileprefix = (tempdir + '/_' + tempbase) if tempdir else tempbase
    unalignedprotsfile = tempfileprefix + '_unaligned_mafft_prots.fasta'
    alignedprotsfile = tempfileprefix + '_aligned_mafft_prots.fasta'
    outputfile = tempfileprefix + '_mafft_output.txt'

    with open(unalignedprotsfile, 'w') as handle:
        records = ['>{0}\n{1}'.format(name, aas)
                   for (name, cds, aas) in headseqprots]
        handle.write('\n'.join(records))

    # MAFFT writes the alignment to stdout and its log to stderr
    with open(alignedprotsfile, 'w') as mafft_out,\
         open(outputfile, 'w') as mafft_err:
        command = [mafftcmd, '--auto', '--thread', '-1', unalignedprotsfile]
        subprocess.check_call(command, stdout=mafft_out, stderr=mafft_err)

    # thread each coding sequence onto its aligned protein
    aligned_by_name = {}
    for record in Bio.SeqIO.parse(alignedprotsfile, 'fasta'):
        name = record.description
        gappedprot = str(record.seq)
        aas = prot_d[name]
        cds = seq_d[name]
        assert not ('-' in aas or '-' in cds),\
               "Already gaps in seq or prot"
        assert len(gappedprot) >= len(aas)
        assert len(gappedprot.replace('-', '')) == len(aas)
        assert len(cds) == len(aas) * 3
        codons = []
        pos = 0  # index of the next residue in the unaligned protein
        for residue in gappedprot:
            if residue != '-':
                assert residue == aas[pos]
                codons.append(cds[3 * pos: 3 * pos + 3])
                pos += 1
            else:
                codons.append('---')
        aligned_by_name[name] = (''.join(codons), gappedprot)

    # restore the input order
    alignedheadseqprots = [(name,) + aligned_by_name[name]
                           for (name, cds, aas) in headseqprots]

    for path in (unalignedprotsfile, alignedprotsfile, outputfile):
        os.remove(path)

    return alignedheadseqprots


def main():
    """Main body of script.

    Parses the command line with ``PhyDMSPrepAlignmentParser`` and then:

    1. removes any existing alignment / plot output files;
    2. reads the input FASTA (``inseqs``) and drops sequences whose header
       contains a ``purgeseqs`` entry;
    3. drops sequences whose length is not a multiple of 3, that contain
       characters other than ``A``, ``C``, ``T``, ``G`` or ``-``, that
       cannot be translated, or that have a premature stop codon;
    4. identifies the single reference sequence (``refseq``);
    5. keeps one sequence per distinct protein, always retaining the
       reference sequence and any matching ``keepseqs``;
    6. aligns with MAFFT unless ``prealigned`` is set;
    7. strips columns that are gapped in the reference sequence;
    8. purges sequences below ``minidentity`` protein identity to the
       reference, then those with fewer than ``minuniqueness`` amino-acid
       differences from an already retained sequence;
    9. writes the final coding-sequence alignment (``alignment``) and a
       scatter plot of retained versus purged sequences.

    Raises ``ValueError`` if ``purgeseqs`` and ``keepseqs`` both match a
    sequence, if zero or several sequences match ``refseq``, or if the
    MAFFT command cannot be executed. Internal consistency checks are
    ``assert`` statements and raise ``AssertionError``.
    """

    # Parse command line arguments
    parser = phydmslib.parsearguments.PhyDMSPrepAlignmentParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # print some basic information
    print('\nBeginning execution of %s in directory %s at time %s\n'
          % (prog, os.getcwd(), time.asctime()))
    print("%s\n" % phydmslib.file_io.Versions())
    print('Parsed the following command-line arguments:\n%s\n' % '\n'
          .join(['\t%s = %s' % tup for tup in args.items()]))

    # remove output files left over from a previous run
    plot = os.path.splitext(args['alignment'])[0] + '.pdf'
    for f in [plot, args['alignment']]:
        if os.path.isfile(f):
            print("Removing existing output file {0}".format(f))
            os.remove(f)

    # get keepseqs and purgeseqs as sets
    # A single value that names an existing file is read as one entry per
    # non-blank line; otherwise the values themselves are the entries.
    for argname in ['keepseqs', 'purgeseqs']:
        argvalue = args[argname]
        if argvalue:
            if len(argvalue) == 1 and \
               argvalue[0] and os.path.isfile(argvalue[0]):
                with open(argvalue[0]) as f:
                    args[argname] = set([line.strip() for line in f
                                         if line and not line.isspace()])
                print("Read {0} sequences from the --{1} file {2}"
                      .format(len(args[argname]), argname, argvalue[0]))
            else:
                args[argname] = set([x for x in argvalue if x])
                print("There are {0} sequences specified by --{1}"
                      .format(len(args[argname]), argname))
        else:
            args[argname] = set([])

    # read sequences, converting them to upper case
    seqs = []
    for seq in Bio.SeqIO.parse(args['inseqs'], 'fasta'):
        seq.seq = seq.seq.upper()
        seqs.append(seq)
    print("Read {0} sequences from {1}".format(len(seqs), args['inseqs']))

    # sequences to purge
    # Entries are matched as substrings of the raw or the cleaned header.
    cleanseqs = []
    for seq in seqs:
        head = seq.description
        cleanhead = CleanHeader(head)
        for purgename in args['purgeseqs']:
            if (purgename in head) or (purgename in cleanhead):
                for keepname in args['keepseqs']:
                    if (keepname in head) or (keepname in cleanhead):
                        raise ValueError("--purgeseqs and --keepseqs conflict "
                                         "about what to do with {0}"
                                         .format(head))
                break
        else:
            # no purge entry matched this header
            cleanseqs.append(seq)
    seqs = cleanseqs
    print("Retained {0} after removing those specified for purging by "
          "'--purgeseqs.'".format(len(seqs)))

    # remove gaps if sequences not pre-aligned
    if not args['prealigned']:
        for seq in seqs:
            seq.seq = seq.seq.ungap(gap='-')

    # check length divisible by 3
    seqs = [seq for seq in seqs if len(seq) % 3 == 0]
    print("Retained {0} sequences after removing any with length not multiple "
          "of 3.".format(len(seqs)))

    # purge ambiguous nucleotide sequences
    seqmatch = re.compile('^[ACTG-]+$')
    seqs = [seq for seq in seqs if seqmatch.search(str(seq.seq))]
    print("Retained {0} sequences after purging any with ambiguous "
          "nucleotides.".format(len(seqs)))

    # make sure all sequences translate without premature stops
    prots = []
    remove_indices = []
    for (i, seq) in enumerate(seqs):
        try:
            prots.append(str(seq.seq.translate(gap='-', stop_symbol='*')))
        except Bio.Data.CodonTable.TranslationError:
            remove_indices.append(i)  # doesn't translate
    # delete from the end so that earlier indices stay valid
    for i in sorted(remove_indices, reverse=True):
        del seqs[i]
    assert len(prots) == len(seqs)
    remove_indices = []
    if args['prealigned']:
        for (i, prot) in enumerate(prots):
            if (prot.count('*') == 1) and prot.replace('-', '')[-1] == '*':
                # convert terminal stop to gap
                # trimmed later if all sites gapped
                istop = prot.index('*')
                prot = list(prot)
                prot[istop] = '-'
                prots[i] = ''.join(prot)
                # introduce gap requires making seq mutable
                iseq = seqs[i].seq.tomutable()
                iseq[3 * istop: 3 * istop + 3] = '---'
                seqs[i].seq = iseq.toseq()
            elif prot.count('*') >= 1:
                # premature stop
                remove_indices.append(i)
    else:
        for (i, prot) in enumerate(prots):
            if prot.count('*') == 1 and prot[-1] == '*':
                # trim terminal stop
                prots[i] = prot[: -1]
                seqs[i].seq = seqs[i].seq[: -3]
            elif prot.count('*') >= 1:
                # premature stop
                remove_indices.append(i)
    for i in sorted(remove_indices, reverse=True):
        del prots[i]
        del seqs[i]
    print("Retained {0} sequences after purging any with premature stops or "
          "that are otherwise un-translateable.".format(len(seqs)))

    # get reference sequence
    refseq = [seq for seq in seqs if (args['refseq'] in seq.description) or
              (CleanHeader(args['refseq']) in seq.description)]
    if len(refseq) == 1:
        refseq = refseq[0]
        print("Using the following as reference sequence: {0}"
              .format(refseq.description))
    elif len(refseq) < 1:
        raise ValueError("Failed to find any sequence with a header that "
                         "contained refseq identifier of {0}\nWas reference "
                         "sequence purged due to not being translate-able?"
                         .format(args['refseq']))
    else:
        raise ValueError("Found multiple sequences with headers that "
                         "contained refseq identifier of {0}"
                         .format(args['refseq']))

    # For removing sequences, we make two lists: sequences to always keeps,
    # then other sequences ordered by frequency that protein occurs.
    # This allows us to preferentially retain common sequences.
    seqs_alwayskeep = []  # entries are (head, seq, prot) as strings
    protcounts = {}  # keyed by protein, values are lists of (head, seq, prot)
    foundrefseq = False
    assert len(seqs) == len(prots)
    for (seqrecord, prot) in zip(seqs, prots):
        head = seqrecord.description
        cleanhead = CleanHeader(head)
        seq = str(seqrecord.seq)
        if head == refseq.description:
            assert not foundrefseq, "Duplicate reference sequences"
            foundrefseq = True
            # the reference sequence is always first in seqs_alwayskeep
            seqs_alwayskeep.insert(0, (cleanhead, seq, prot))
        else:
            for keepsubstring in args['keepseqs']:
                if (keepsubstring in head) or (keepsubstring in cleanhead):
                    seqs_alwayskeep.append((cleanhead, seq, prot))
                    break
            else:
                if prot in protcounts:
                    protcounts[prot].append((cleanhead, seq, prot))
                else:
                    protcounts[prot] = [(cleanhead, seq, prot)]
    assert len(seqs) ==\
           len(seqs_alwayskeep) + sum([len(x) for x in protcounts.values()]),\
           "Somehow lost sequences when looking for unique seqs."
    assert foundrefseq, "Never found refseq to add to seqs_alwayskeep"
    # get one unique copy of each protein sorted by number of copies
    seqs_unique = sorted([(len(x), x[0]) for x in protcounts.values()],
                         reverse=True)
    seqs_unique = [seqtup for (count, seqtup) in seqs_unique]
    print("Purged sequences encoding redundant proteins, being sure to retain "
          "reference sequence and any specified by '--keepseqs'. Overall, {0} "
          "sequences remain.".format(len(seqs_unique) + len(seqs_alwayskeep)))
    heads = [tup[0] for tup in seqs_unique] +\
            [tup[0] for tup in seqs_alwayskeep]
    assert len(heads) == len(set(heads)),\
           ("Found sequences with non-unique cleaned headers:\n{0}"
           .format('\n'.join([head for head in heads if heads.count(head) > 1]
                             )))

    if args['prealigned']:
        print("You specified '--prealigned', "
              "so the sequences are NOT being aligned.")
        # make sure sequences all the same length
        seqlengths = [len(tup[1]) for tup in seqs_unique + seqs_alwayskeep]
        assert all([seqlengths[0] == n for n in seqlengths]),\
               ("You indicated sequences are pre-aligned with the "
                "'--prealigned' option, but they are not all of the same "
                "length.")
    else:
        # align all sequences
        try:
            mafftversion = subprocess.check_output([args['mafft'],
                                                    '--version'],
                                                   stderr=subprocess.STDOUT)
        except OSError:
            raise ValueError("The mafft command of {0} is not valid. "
                             "Is mafft installed at this path?"
                             .format(args['mafft']))
        # check_output returns bytes on Python 3; decode so the log shows the
        # version text rather than a b'...' representation
        if not isinstance(mafftversion, str):
            mafftversion = mafftversion.decode('utf-8', 'replace')
        mafftversion = mafftversion.strip()
        print("Aligning sequences with MAFFT using {0} {1}"
              .format(args['mafft'], mafftversion))
        alignment = MAFFT_CDSandProtAlignments(seqs_unique + seqs_alwayskeep,
                                               args['mafft'], args['alignment']
                                               )
        # split the aligned list back into its two parts
        nunique = len(seqs_unique)
        seqs_unique = alignment[: nunique]
        seqs_alwayskeep = alignment[nunique:]
        seqlengths = [len(tup[1]) for tup in seqs_unique + seqs_alwayskeep]
        assert all([seqlengths[0] == n for n in seqlengths]),\
               ("MAFFT alignment failed to yield sequences all of the "
                "same length.")

    # strip gaps relative refseq
    assert seqs_alwayskeep[0][0] == CleanHeader(refseq.description),\
           ("Expected refseq to have this header:\t{0}\nFound this:\t{1}"
           .format(refseq.description, seqs_alwayskeep[0][0]))
    refseqprot = seqs_alwayskeep[0][2]
    # protein positions (codon indices) where the reference has a gap
    gapped = set([i for i in range(len(refseqprot)) if refseqprot[i] == '-'])
    for xlist in [seqs_alwayskeep, seqs_unique]:
        for (i, (head, seq, prot)) in enumerate(xlist):
            strippedprot = ''.join([aa for (j, aa) in enumerate(prot)
                                    if j not in gapped])
            strippedseq = ''.join([seq[3 * j: 3 * j + 3] for j in
                                   range(len(prot)) if j not in gapped])
            xlist[i] = (head, strippedseq, strippedprot)
    refseqprot = seqs_alwayskeep[0][2]  # with gaps removed
    assert '-' not in refseqprot, ("Gap stripping from reference sequence "
                                   "failed.")
    refseq = seqs_alwayskeep[0][1]  # with gaps removed
    assert '-' not in refseq, "Gap stripping from reference sequence failed."
    print("After stripping gaps relative to reference sequence, all proteins "
          "are of length {0}".format(len(refseqprot)))

    # make seqs_unique,
    # seq_alwayskeep be *(head, seq, prot, ntident, protident)*
    # identities are the fraction of reference positions that match
    refseqprotlength = float(len(refseqprot))
    refseqlength = float(len(refseq))
    for xlist in [seqs_alwayskeep, seqs_unique]:
        for (i, (head, seq, prot)) in enumerate(xlist):
            protident = (sum([x == y for (x, y) in zip(prot, refseqprot)]) /
                         refseqprotlength)
            ntident = (sum([x == y for (x, y) in zip(seq, refseq)]) /
                       refseqlength)
            xlist[i] = (head, seq, prot, ntident, protident)

    # purge sequences that don't meet identity threshold
    # (only seqs_unique is filtered; seqs_alwayskeep is left untouched)
    purged_idents = []  # list of *(ntident, protident)* for purged sequences
    for xlist in [seqs_unique]:
        remove_indices = []
        for (i, (head, seq, prot, ntident, protident)) in enumerate(xlist):
            if protident < args['minidentity']:
                purged_idents.append((ntident, protident))
                remove_indices.append(i)
        for i in sorted(remove_indices, reverse=True):
            del xlist[i]
    print("Retained {0} sequences after purging those with < {1} protein "
          "identity to reference sequence."
          .format(len(seqs_alwayskeep + seqs_unique), args['minidentity']))

    # purge sequences that don't meet uniqueness threshold
    # A sequence from seqs_unique is added only if it differs by at least
    # minuniqueness amino acids from every sequence retained so far.
    alignment = seqs_alwayskeep
    for (head, seq, prot, ntident, protident) in seqs_unique:
        for (head2, seq2, prot2, ntident2, protident2) in alignment:
            protdiffs = sum([x != y for (x, y) in zip(prot, prot2)])
            if protdiffs < args['minuniqueness']:
                purged_idents.append((ntident, protident))
                break
        else:
            alignment.append((head, seq, prot, ntident, protident))
    print("Retained {0} sequences after purging those without at least {1} "
          "amino-acid differences with other retained sequences."
          .format(len(alignment), args['minuniqueness']))

    # write final alignment of coding sequences
    print("Writing the final alignment of {0} coding sequences to {1}"
          .format(len(alignment), args['alignment']))
    with open(args['alignment'], 'w') as f:
        for (head, seq, prot, ntident, protident) in alignment:
            f.write('>{0}\n{1}\n'.format(head, seq))

    # make plot
    retained_idents = [(ntident, protident) for
                       (head, seq, prot, ntident, protident) in alignment]
    print("Plotting retained and purged sequences to {0}".format(plot))
    plt.figure()
    points = []
    labels = []
    for (label, idents, color, marker, alpha) in [
            ('purged', purged_idents, 'b', 'o', 0.25),
            ('retained', retained_idents, 'r', 'd', 0.5)
            ]:
        pt = plt.scatter(
                [tup[0] for tup in idents],
                [tup[1] for tup in idents],
                c=color,
                marker=marker,
                s=30,
                alpha=alpha)
        points.append(pt)
        labels.append(label)
    plt.xlabel('DNA identity relative to refseq', fontsize=14)
    plt.ylabel('protein identity relative to refseq', fontsize=14)
    plt.legend(points, labels, scatterpoints=1, loc='upper left', fontsize=14)
    plt.savefig(plot)

    # program done
    print('\nSuccessful completion of %s' % prog)


# Call ``main`` only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()  # run the script
