#!/usr/bin/env python3
import argparse
import os

import pandas as pd
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from deepblast.trainer import DeepBLAST
from deepblast.dataset.parse_pdb import readPDB
# from deepblast.dataset.parse_mali import read_mali
from transformers import T5EncoderModel, T5Tokenizer


def main(args):
    """Align every PDB pair listed in a CSV with DeepBLAST and save the result.

    The table at ``args.mali_pairs`` is read with its first CSV column as the
    index.  The first two remaining columns of each row are treated as PDB
    file names relative to ``args.input_mali_dir``; the sequences parsed from
    those files are aligned and the result is stored in a new ``deepblast``
    column.  Columns labelled ``'0'``, ``'1'`` and ``'2'`` are then renamed to
    ``query_seq``, ``hit_seq`` and ``manual`` and the table is written to
    ``args.output_alignments``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``mali_pairs``, ``pretrain_path``,
        ``load_from_checkpoint``, ``alignment_mode``, ``accelerator``,
        ``input_mali_dir`` and ``output_alignments``.
    """
    pairs = pd.read_csv(args.mali_pairs, index_col=0)

    tokenizer = T5Tokenizer.from_pretrained(
        args.pretrain_path, do_lower_case=False)
    lm = T5EncoderModel.from_pretrained(args.pretrain_path)
    model = DeepBLAST.load_from_checkpoint(
        args.load_from_checkpoint, lm=lm, tokenizer=tokenizer,
        alignment_mode=args.alignment_mode)

    if args.accelerator == 'gpu':
        model = model.cuda()

    # Select the two file-name columns by explicit position.  Indexing a row
    # Series with ``row[0]`` relies on the deprecated positional fallback of
    # ``Series.__getitem__`` when the labels are strings.
    first_files = pairs.iloc[:, 0]
    second_files = pairs.iloc[:, 1]

    alignments = []
    for pdb0, pdb1 in zip(first_files, second_files):
        _, fpnts0 = readPDB(f'{args.input_mali_dir}/{pdb0}')
        _, fpnts1 = readPDB(f'{args.input_mali_dir}/{pdb1}')
        x = fpnts0.seq
        y = fpnts1.seq
        # NOTE(review): the second structure's sequence is passed first;
        # confirm this matches the argument order DeepBLAST.align expects.
        alignments.append(model.align(y, x))

    pairs['deepblast'] = alignments
    pairs = pairs.rename(
        columns={'0': 'query_seq', '1': 'hit_seq', '2': 'manual'})
    pairs.to_csv(args.output_alignments)


if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Run DeepBLAST on pairs of PDB files and write the "
                    "alignments to a CSV file.")
    parser.add_argument('--accelerator', type=str, default='cpu',
                        help="Use 'gpu' to move the model to CUDA; any other "
                             "value leaves the model as loaded.")
    parser.add_argument('--alignment-mode', type=str,
                        default='needleman-wunsch',
                        help="Alignment mode forwarded to the DeepBLAST "
                             "model.")
    # The path options below are mandatory.  Without them the run either
    # fails late with an unrelated error or, for --output-alignments,
    # finishes without writing anything because to_csv(None) only returns
    # the CSV text.
    parser.add_argument('--load-from-checkpoint', type=str, required=True,
                        help="Path to the DeepBLAST checkpoint to load.")
    parser.add_argument('--input-mali-dir', type=str, required=True,
                        help="Root directory for Mali PDB files.")
    parser.add_argument('--mali-pairs', type=str, required=True,
                        help="Specifies file paths for pdb pairs.")
    parser.add_argument('--pretrain-path', type=str, required=True,
                        help="Path to pretrained model.")
    parser.add_argument('--output-alignments', type=str, required=True,
                        help="Output alignments.")
    hparams = parser.parse_args()
    main(hparams)
