#!/usr/bin/env python

import numpy as np
import scipy.stats as st
from pickle import load,loads,dump,dumps
import click
import blosc

import os
from pathlib import Path
# Directory where PROST keeps its sequence cache and the bundled SwissProt
# database: $PROSTDIR when that variable is set, otherwise ~/.config/prost.
try:
    prostdir = os.environ['PROSTDIR']
except KeyError:
    prostdir = str(Path.home())+'/.config/prost'

from itertools import groupby
def fasta_iter(fastafile):
    '''Yield (header, sequence) pairs from a FASTA file.

    The header is the text after the leading ">" with surrounding whitespace
    removed. The sequence is the concatenation of all following non-header
    lines, each stripped of surrounding whitespace.

    Lines that appear before the first header are ignored, and a header that
    is not followed by any sequence line yields an empty sequence. The file
    is closed as soon as the iterator is exhausted or closed.
    '''
    header = None
    chunks = []
    with open(fastafile) as fh:
        for line in fh:
            if line.startswith('>'):
                # A new record starts: flush the one collected so far.
                if header is not None:
                    yield header, ''.join(chunks)
                header = line[1:].strip()
                chunks = []
            elif header is not None:
                chunks.append(line.strip())
        # Flush the last record; nothing to do for a file without headers.
        if header is not None:
            yield header, ''.join(chunks)


def check_seq(seq):
    '''Check that a sequence consists of known amino-acid letters only.

    The comparison is case-insensitive. Accepted letters are the 20 standard
    residues plus the codes X, B, U, Z and O.

    Returns (True, '') when every letter is accepted, otherwise
    (False, letter) where letter is the first upper-cased offending character.
    '''
    allowed = frozenset('ACDEFGHIKLMNPQRSTVWY' + 'XBUZO')
    offending = next((ch for ch in seq.upper() if ch not in allowed), None)
    if offending is None:
        return True,''
    return False,offending

@click.command()
@click.argument('fasta', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('out', type=click.Path(exists=False,file_okay=True,dir_okay=False))
def createdb(fasta, out):
    '''PROST python package v0.1 createdb command.
createdb command gets a fasta file and creates a PROST database that can be used as query or target database in a search.'''
    # The docstring above is what click shows as the command's help text, so
    # developer notes are kept in comments.
    #
    # fasta: path of the input FASTA file.
    # out:   path of the database to write; a blosc-compressed pickle of
    #        [names, embeddings].
    from pyprost import quantSeq

    # Embeddings are memoised per sequence string in a pickle under prostdir.
    # SECURITY: unpickling can execute arbitrary code, so the cache file must
    # only ever come from a trusted location.
    cache_path = prostdir+'/cache.pkl'
    if os.path.exists(cache_path):
        with open(cache_path,'rb') as f:
            cache = load(f)
    else:
        cache = {}
    cache_dirty = False

    quant = []      # one embedding per accepted sequence
    names = []      # accepted names, in the same order as quant
    seen = set()    # fast duplicate-name lookup

    for name,seq in fasta_iter(fasta):
        seq_len = len(seq)
        if seq_len < 5:
            print(name,'discarded, length:',seq_len)
            continue

        status,offchar = check_seq(seq)
        if not status:
            print(name,'contains unknown aa',offchar)
            continue

        # Only the first occurrence of a name is kept.
        if name in seen:
            print(name,'is already exits!')
            continue

        seen.add(name)
        names.append(name)

        if seq in cache:
            quant.append(cache[seq])
        else:
            print(name,'not found in cache. Quantize it.')
            qseq = quantSeq(seq)
            quant.append(qseq)
            cache[seq] = qseq
            cache_dirty = True

        # Every stored embedding must have 475 components.
        assert np.shape(quant[-1])[0] == 475

    assert len(names) == np.shape(quant)[0]
    print('Total number of sequences embedded in the db:',len(names))

    with open(out,'wb') as f:
        f.write(blosc.compress(dumps([np.array(names),np.array(quant)])))

    if cache_dirty:
        # The cache directory may not exist yet (e.g. first run); create it so
        # the freshly computed embeddings are not lost.
        os.makedirs(prostdir, exist_ok=True)
        with open(cache_path,'wb') as f:
            dump(cache,f)

def _search(thr, querydb, targetdb, out):
    '''Search each query embedding against a target database and write the hits.

    Both databases are files in the format written by createdb: a
    blosc-compressed pickle of [names, embeddings].

    For every query, the distance to each target is the sum of absolute
    differences over the 475 embedding components. The distances are turned
    into robust z-scores (median and scaled median absolute deviation) and
    the e-value is the normal CDF of the z-score multiplied by the number of
    targets. Targets whose e-value is below thr are reported, best first.

    thr: e-value threshold; only hits strictly below it are written.
    querydb: path of the query database.
    targetdb: path of the target database.
    out: path of the result file; one tab-separated line per hit holding
        query name, target name, halved distance (formatted with %d) and
        e-value.
    '''
    # SECURITY: unpickling can execute arbitrary code; only open database
    # files that come from a trusted source.
    with open(querydb,'rb') as f:
        qnames,qdb = loads(blosc.decompress(f.read()))
    with open(targetdb,'rb') as f:
        tnames,tdb = loads(blosc.decompress(f.read()))
    ldb = len(tdb)
    output = []

    # An empty target database is stored as a 1-D array of length 0, which
    # cannot be broadcast against a 475-long query. There is nothing to match
    # in that case, so the scoring loop is skipped and no hits are written.
    if ldb > 0:
        # Scratch buffer reused for every query.
        # NOTE(review): differences are stored as int8; presumably the
        # quantized values are small enough not to overflow - confirm.
        mem = np.zeros((ldb,475),dtype='int8')
        for i,q in enumerate(qdb):
            np.subtract(tdb,q,out=mem)
            np.absolute(mem,out=mem)
            dbdiff = mem.sum(axis=1)
            m=np.median(dbdiff)
            # 1.4826 scales the MAD to the standard deviation of a normal
            # distribution. A zero MAD makes the division below produce
            # inf/nan values.
            s=st.median_abs_deviation(dbdiff)*1.4826
            zscore = (dbdiff-m)/s
            e = st.norm.cdf(zscore)*ldb
            res = np.where(e < thr)[0]
            # Order the hits by ascending e-value.
            sort = np.argsort(e[res])
            res = res[sort]
            dbdiff = dbdiff[res]/2
            evals = e[res]
            names = tnames[res]

            for n,diff,ev in zip(names,dbdiff,evals):
                output.append('%s\t%s\t%d\t%.2e'%(qnames[i],n,diff,ev))

    with open(out,'w') as f:
        f.writelines(o+'\n' for o in output)

@click.command()
@click.option('--thr', default=0.05, help='E-value threshold for homolog detection')
@click.argument('querydb', type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument('targetdb', type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument('out', type=click.Path(exists=False, file_okay=True, dir_okay=False))
def search(thr, querydb, targetdb, out):
    '''PROST python package v0.1 search command.
This command searches a query database against a target database.
Both databases should be created using createdb command. 
Databases can contain one or more sequences.
An e-value threshold can be specified with --thr flag. The default e-value threshold is 0.05'''
    # Thin CLI wrapper: the command-line values are handed to _search as-is.
    _search(thr=thr, querydb=querydb, targetdb=targetdb, out=out)

@click.command()
@click.option('--thr', default=0.05, help='E-value threshold for homolog detection')
@click.argument('querydb', type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.argument('out', type=click.Path(exists=False, file_okay=True, dir_okay=False))
def searchsp(thr, querydb, out):
    '''PROST python package v0.1 search SwissProt command.
This command searches a query database against a SwissProt January 2022 database.
Query database should be created using createdb command. 
It can contain one or more sequences.
An e-value threshold can be specified with --thr flag. The default e-value threshold is 0.05'''
    # The target is always the SwissProt database file stored under prostdir.
    swissprot_db = f'{prostdir}/sp.01.22.prdb'
    _search(thr=thr, querydb=querydb, targetdb=swissprot_db, out=out)

# Top-level command group; it does no work itself and only hosts the
# sub-commands registered on it.
@click.group()
def cli():
    '''PROST python package v0.1.
Please specify a command.
createdb: creates a PROST database from given fasta file. The fasta file usually contains more than one entry.
search: searches a query database against a target database. Query database can contain one or more sequences embedded using createdb command.
searchsp: searches a query database against SwissProt January 2022 database. Query database can contain one or more sequences embedded using createdb command.'''

# Attach every sub-command to the top-level group, in the original order.
for _command in (createdb, search, searchsp):
    cli.add_command(_command)

# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    cli()
