#!/usr/bin/env python3

import argparse
import os

import torch
import numpy as np
from pytorch_lightning import Trainer
from deepblast.trainer import LightningAligner
from deepblast.dataset import FastaDataset
from deepblast.dataset.utils import collate_fasta_f
from deepblast.dataset.utils import unpack_sequences
from torch.utils.data import DataLoader
from tqdm import tqdm


def fasta_dataloader(args):
    """Build the ``DataLoader`` that feeds sequence pairs to the aligner.

    A ``FastaDataset`` is created from ``args.query_fasta`` and
    ``args.db_fasta`` and wrapped in a non-shuffling ``DataLoader`` that
    collates batches with ``collate_fasta_f``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``query_fasta``, ``db_fasta``, ``batch_size`` and
        ``num_workers``.

    Returns
    -------
    torch.utils.data.DataLoader
        Loader over the dataset with pinned-memory batches enabled.
    """
    pairs = FastaDataset(args.query_fasta, args.db_fasta)
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fasta_f,
        pin_memory=True,
    )
    return DataLoader(pairs, **loader_options)


def main(args):
    """Score every sequence pair and write the results as a TSV file.

    A ``LightningAligner`` is restored from ``args.load_from_checkpoint``
    and applied to each batch produced by ``fasta_dataloader(args)``.
    For every pair one tab-separated line is written to
    ``args.output_file``::

        query_id <TAB> db_id <TAB> score <TAB> normalized_score

    ``normalized_score`` is the alignment score divided by the product of
    the two sequence lengths; both numbers are rounded to 4 decimals.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``load_from_checkpoint``, ``output_file`` and ``gpu``
        in addition to the attributes read by ``fasta_dataloader``.
    """
    print('args', args)
    model = LightningAligner.load_from_checkpoint(
        args.load_from_checkpoint)
    # NOTE(review): the model is used in whatever mode
    # ``load_from_checkpoint`` returns it -- confirm whether an explicit
    # ``model.eval()`` is wanted for scoring.
    dataloader = fasta_dataloader(args)
    if args.gpu:
        model = model.cuda()
    with open(args.output_file, 'w') as out_handle:
        for batch in tqdm(dataloader):
            qids, dids, seqs, order = batch
            if args.gpu:
                # The inputs must live on the same device as the model.
                seqs = seqs.cuda()
            A = model.align.score(seqs, order)
            _, qlen, _, dlen = unpack_sequences(seqs, order)
            for i, (ql, dl) in enumerate(zip(qlen, dlen)):
                # ``np.asscalar`` was removed in NumPy 1.23; ``Tensor.item``
                # yields the same Python scalar without the NumPy round trip.
                aln_score = A[i].item()
                # Normalize by the size of the alignment matrix (|q| * |d|).
                norm_score = aln_score / (float(ql) * float(dl))
                q_id, d_id = qids[i], dids[i]
                res = [q_id, d_id, np.round(aln_score, decimals=4),
                       np.round(norm_score, decimals=4)]
                res = list(map(str, res))
                out_handle.write('\t'.join(res) + '\n')


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw
    string, so every non-empty value -- including ``'False'`` and ``'0'`` --
    becomes ``True``. This converter interprets the usual spellings instead.

    Parameters
    ----------
    value : str or bool
        Token taken from the command line (or an already parsed boolean).

    Returns
    -------
    bool
        ``True`` for ``true/t/yes/y/1``, ``False`` for ``false/f/no/n/0``
        and the empty string (case-insensitive).

    Raises
    ------
    argparse.ArgumentTypeError
        If the token is not a recognized boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``bool('')``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--query-fasta', type=str, required=True)
    parser.add_argument('--db-fasta', type=str, required=True)
    parser.add_argument('--load-from-checkpoint', type=str, required=True)
    parser.add_argument('--output-file', type=str, required=True)
    # ``nargs='?'`` with ``const=True`` lets a bare ``--gpu`` enable the GPU,
    # while ``--gpu True`` / ``--gpu False`` keep working and now mean what
    # they say.
    parser.add_argument('--gpu', type=str2bool, nargs='?', const=True,
                        default=None,
                        help='run the model on the GPU '
                             '(e.g. --gpu, --gpu True, --gpu False)')
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=10)

    hparams = parser.parse_args()
    main(hparams)
