#!/usr/bin/env python3

import argparse
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from deepblast.trainer import DeepBLAST
from transformers import T5EncoderModel, T5Tokenizer


def main(args):
    """Train a DeepBLAST model with PyTorch Lightning.

    Loads the T5 tokenizer and encoder from ``args.pretrain_path``, builds a
    ``DeepBLAST`` module (restored from ``args.load_from_checkpoint`` when
    that option is set, otherwise constructed from the parsed options), fits
    it with a DDP ``Trainer`` and finally writes the model ``state_dict`` to
    ``<args.output_directory>/last_ckpt.pt``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. Must provide the trainer options
        declared in this script (``accelerator``, ``devices``, ``nodes``,
        ``grad_accum``, ``grad_clip``, ``precision``, ``limit_train_batches``,
        ``limit_val_batches``, ``load_from_checkpoint``) together with
        ``pretrain_path``, ``output_directory``, ``epochs`` and the model
        hyper-parameters listed in ``hparam_keys`` below (presumably supplied
        by ``DeepBLAST.add_model_specific_args`` -- confirm against that
        method).

    Raises
    ------
    KeyError
        If a fresh model is built and one of the hyper-parameters listed in
        ``hparam_keys`` is missing from ``args``.
    """
    print('args', args)
    tokenizer = T5Tokenizer.from_pretrained(
        args.pretrain_path, do_lower_case=False)
    lm = T5EncoderModel.from_pretrained(args.pretrain_path)

    if args.load_from_checkpoint is not None:
        model = DeepBLAST.load_from_checkpoint(
            args.load_from_checkpoint, lm=lm, tokenizer=tokenizer,
            device=args.accelerator)
    else:
        # Options forwarded verbatim from the command line to DeepBLAST.
        hparam_keys = (
            'batch_size',
            'hidden_dim',
            'embedding_dim',
            'epochs',
            'finetune',
            'layers',
            'dropout',
            'learning_rate',
            'loss',
            'mask_gaps',
            'multitask',
            'num_workers',
            'output_directory',
            'scheduler',
            'test_pairs',
            'train_pairs',
            'valid_pairs',
            'visualization_fraction',
            'shuffle_validation',
        )
        # vars() exposes the namespace's own __dict__; read from it only, so
        # that `args` is not mutated by the extra model inputs added below.
        cli_options = vars(args)
        model_args = {key: cli_options[key] for key in hparam_keys}
        model_args['lm'] = lm
        model_args['tokenizer'] = tokenizer
        # The model's `device` always follows the trainer accelerator, so it
        # is never looked up on the command line.
        model_args['device'] = args.accelerator
        model = DeepBLAST(**model_args)

    # Checkpoint saver monitoring the `validation_loss` metric.
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.output_directory,
        monitor='validation_loss',
        filename="{epoch}-{step}-{validation_loss:0.4f}",
        verbose=True
    )

    trainer = Trainer(
        max_epochs=args.epochs,
        devices=args.devices,
        accelerator=args.accelerator,
        num_nodes=args.nodes,
        accumulate_grad_batches=args.grad_accum,
        gradient_clip_val=args.grad_clip,
        gradient_clip_algorithm="norm",
        precision=args.precision,
        val_check_interval=1.0,
        fast_dev_run=False,
        strategy=DDPStrategy(find_unused_parameters=True),
        limit_train_batches=args.limit_train_batches,
        limit_val_batches=args.limit_val_batches,
        callbacks=[checkpoint_callback]
    )

    print('model', model)
    trainer.fit(model)

    # Fallback in case the checkpoint callback did not save anything. Under
    # DDP every process reaches this point, so only the global rank-zero
    # process writes, to avoid concurrent writes to the same file.
    if trainer.is_global_zero:
        os.makedirs(args.output_directory, exist_ok=True)
        torch.save(model.state_dict(),
                   os.path.join(args.output_directory, 'last_ckpt.pt'))


if __name__ == '__main__':
    # Script entry point: declare the trainer-level options, let DeepBLAST
    # add its own options, then parse the command line and run training.
    # NOTE(review): add_help=False presumably exists because
    # DeepBLAST.add_model_specific_args wraps this parser as a parent and
    # supplies its own -h/--help -- confirm against that method.
    parser = argparse.ArgumentParser(add_help=False)

    # Hardware layout, forwarded to Trainer(accelerator/devices/num_nodes).
    parser.add_argument('--accelerator', type=str, default=None)
    parser.add_argument('--devices', type=int, default=None)
    parser.add_argument('--nodes', type=int, default=1)

    # Optimisation options, forwarded to
    # Trainer(accumulate_grad_batches/gradient_clip_val/precision).
    parser.add_argument('--grad-accum', type=int, default=1)
    # Parsed as float so fractional clipping thresholds (e.g. 0.5) are
    # accepted; integer values such as "10" still parse.
    parser.add_argument('--grad-clip', type=float, default=10.0)
    parser.add_argument('--precision', type=int, default=32)

    # Data loading and per-epoch batch limits.
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--limit-train-batches', type=int, default=5000)
    parser.add_argument('--limit-val-batches', type=int, default=50)

    # Optional checkpoint to restore the model from instead of building a
    # fresh one.
    parser.add_argument('--load-from-checkpoint', type=str, default=None)

    # No option selects the distributed strategy: main() always uses
    # DDPStrategy.
    parser = DeepBLAST.add_model_specific_args(parser)
    hparams = parser.parse_args()
    main(hparams)
