#!/usr/bin/env python
'''
Simple Smith-Waterman aligner
'''

import sys
import os
try:
    import swalign
except ImportError:
    # swalign is not importable from the current path: add the parent
    # directory of this script to sys.path and retry (presumably so the
    # script can run from a source checkout -- confirm against repo layout).
    # Only ImportError is caught so that Ctrl-C or a genuine error raised
    # inside an installed swalign is not silently masked by the fallback.
    sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..'))
    import swalign

def usage():
    '''
    Write the module description and the command-line help to stderr, then
    terminate the process with exit status 1.

    This function never returns normally: it always raises SystemExit via
    sys.exit(1).
    '''
    # __doc__ is None when docstrings are stripped (python -OO); writing None
    # would raise a TypeError instead of showing the help text.
    sys.stderr.write(__doc__ or '')
    sys.stderr.write('''
Usage: swalign {options} ref query

Reference and query arguments can either be written on the command-line, read
from stdin, or read as FASTA format files. If there is more than one sequence
in the reference FASTA file, the query will be aligned to all reference
sequences and only the best scoring alignment will be displayed. If more than
one sequence is in a query FASTA file, each query sequence will be aligned to
the reference.

Alignments will be made in both forward and reverse directions.

Options:
  -m N              Match score (default: 2)
  -mm N             Mismatch penalty (default: 1)
  -gap N            Gap penalty (default: 1)
  -gapext N         Gap extension penalty (default: 1)
  -gapdecay N       Decay the gap extension penalty (default: 0.0)
  -wrap N           Wrap alignments when they are longer than N bases
  -global           Perform a global alignment (experimental)
  -query            Align the full query sequence (mix of local/global)
  -summary fname    Write a summary file of match locations (tab-delimited)
  -useregion        Use regions for coordinates if included in FASTA ref
  -progress         Show progress on stderr while aligning
  -v                Verbose output

Example:
    ~$ swalign AAGGGGAGGACGATGCGGATGTTC AGGGAGGACGATGCGG

''')
    sys.exit(1)


if __name__ == '__main__':
    # Command-line entry point: parse the options, align every query sequence
    # against every reference sequence on both strands, print the best
    # alignment for each query and optionally write a tab-delimited summary.

    # ref/query hold callables (built by swalign.fasta_gen / swalign.seq_gen)
    # that are invoked below to iterate (name, sequence, comments) tuples.
    ref = None
    query = None

    # Scoring defaults. Penalties are kept as non-positive numbers; the option
    # parsing below normalizes user input to that sign.
    match = 2
    mismatch = -1
    gap_penalty = -1
    gap_extension_penalty = -1
    gap_extension_decay = 0.0
    wrap = None
    verbose = False
    globalalign = False
    useregion = False
    summary = False  # False, or the output filename given with -summary
    progress = False
    full_query = False

    # Name of the option whose value is expected next (None if none pending).
    last = None

    for arg in sys.argv[1:]:
        if last == '-m':
            match = int(arg)
            last = None
        elif last == '-mm':
            # Accept the penalty with either sign; always store it as <= 0.
            mismatch = -int(arg)
            if mismatch > 0:
                mismatch = -mismatch
            last = None
        elif last == '-gap':
            gap_penalty = float(arg)
            if gap_penalty > 0:
                gap_penalty = -gap_penalty
            last = None
        elif last == '-gapext':
            gap_extension_penalty = float(arg)
            if gap_extension_penalty > 0:
                gap_extension_penalty = -gap_extension_penalty
            last = None
        elif last == '-gapdecay':
            # Unlike the penalties, the decay is stored as a value >= 0.
            gap_extension_decay = float(arg)
            if gap_extension_decay < 0:
                gap_extension_decay = -gap_extension_decay
            last = None
        elif last == '-wrap':
            wrap = int(arg)
            last = None
        elif last == '-summary':
            summary = arg
            last = None
        elif arg in ['-m', '-mm', '-gap', '-gapext', '-gapdecay', '-wrap', '-summary']:
            # Option that takes a value: remember it and consume the value on
            # the next iteration.
            last = arg
        elif arg == '-progress':
            progress = True
        elif arg == '-global':
            globalalign = True
        elif arg == '-v':
            verbose = True
        elif arg == '-query':
            full_query = True
        elif arg == '-useregion':
            useregion = True
        elif not ref:
            # First positional argument: an existing path or '-' is handed to
            # fasta_gen, anything else is treated as a literal sequence.
            if os.path.exists(arg) or arg == '-':
                ref = swalign.fasta_gen(arg)
            else:
                ref = swalign.seq_gen('cmdline', arg)
        elif not query:
            # Second positional argument, interpreted the same way.
            if os.path.exists(arg) or arg == '-':
                query = swalign.fasta_gen(arg)
            else:
                query = swalign.seq_gen('cmdline', arg)

    if last is not None:
        # A value-taking option was the final argument; report it instead of
        # silently falling back to the default.
        sys.stderr.write('Missing value for option: %s\n' % last)
        usage()

    if not ref or not query:
        usage()

    sw = swalign.LocalAlignment(
        swalign.NucleotideScoringMatrix(match, mismatch),
        gap_penalty, gap_extension_penalty,
        gap_extension_decay=gap_extension_decay, verbose=verbose, globalalign=globalalign, full_query=full_query)

    # Rows for the optional summary file, one tuple per reported alignment.
    summarized = []

    for q_name, q_seq, q_comments in query():
        if progress:
            sys.stderr.write('%s: ' % q_name)

        # Highest-scoring alignment seen so far for this query, together with
        # the comments of the reference sequence it came from.
        best = None
        best_r_comments = None
        for r_name, r_seq, r_comments in ref():
            if progress:
                sys.stderr.write('.')
                sys.stderr.flush()

            for strand in '+-':
                if strand == '-':
                    # Reverse strand: align the reverse complement of the query.
                    aln = sw.align(r_seq, swalign.revcomp(q_seq),
                        ref_name=r_name, query_name=q_name, rc=True)

                else:
                    aln = sw.align(r_seq, q_seq, ref_name=r_name,
                        query_name=q_name)

                # Strict '>' keeps the earliest alignment on ties. Compare
                # against None explicitly so the choice does not depend on the
                # truthiness of the alignment object.
                if best is None or aln.score > best.score:
                    best = aln
                    best_r_comments = r_comments

        if progress:
            sys.stderr.write('\n')
            sys.stderr.flush()

        if best is None:
            # ref() produced no sequences, so there is nothing to report for
            # this query; say so instead of failing on best.dump().
            sys.stderr.write('No reference sequences to align against: %s\n' % q_name)
            continue

        if useregion:
            ref_start = swalign.extract_region(best_r_comments)
            if ref_start:
                best.set_ref_offset(*ref_start)

        best.dump(wrap)

        if summary:
            summarized.append((q_name, best.r_name, best.q_pos, best.q_end, best.r_pos + best.r_offset, best.r_end + best.r_offset, '-' if best.rc else '+', best.cigar_str))
        # Blank separator line. The parenthesized form prints the same empty
        # line under Python 2 and is also valid Python 3 syntax.
        print("")

    if summary:
        with open(summary, 'w') as f:
            f.write("query\tref\tquery_start\tquery_end\tref_start\tref_end\tstrand\tcigar\n")
            for s in summarized:
                f.write('%s\n' % '\t'.join([str(x) for x in s]))
