"""
This module defines a generic trainer for simple models and datasets.
"""

# Externals
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

# Locals
from .base import BaseTrainer
from models import get_model

class GenericTrainer(BaseTrainer):
    """Trainer code for basic classification problems.

    Supports a cross-entropy ('CE') or binary cross-entropy with logits
    ('BCE') loss and an SGD optimizer. Model construction is delegated to
    ``models.get_model``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build_model(self, model_type='resnet', loss='CE',
                    optimizer='SGD', learning_rate=0.01,
                    momentum=0.9, **model_args):
        """Instantiate the model, optimizer and loss function.

        Args:
            model_type: name passed to ``get_model`` to construct the model.
            loss: loss name, 'CE' (CrossEntropyLoss) or 'BCE'
                (BCEWithLogitsLoss).
            optimizer: optimizer name; only 'SGD' is currently supported.
            learning_rate: optimizer learning rate.
            momentum: SGD momentum.
            **model_args: extra keyword arguments forwarded to ``get_model``.

        Raises:
            KeyError: if ``optimizer`` or ``loss`` is not a supported name.
        """
        # TODO: add support for more optimizers and loss functions here
        optimizers = dict(SGD=torch.optim.SGD)
        losses = dict(CE=torch.nn.CrossEntropyLoss,
                      BCE=torch.nn.BCEWithLogitsLoss)
        # Validate the names up front so a typo fails before the model is built
        if optimizer not in optimizers:
            raise KeyError('Unsupported optimizer {!r}; expected one of {}'
                           .format(optimizer, sorted(optimizers)))
        if loss not in losses:
            raise KeyError('Unsupported loss {!r}; expected one of {}'
                           .format(loss, sorted(losses)))
        self.loss = loss

        # Construct the model
        self.model = get_model(name=model_type, **model_args).to(self.device)

        # Distributed data parallelism
        if self.distributed:
            device_ids = [self.gpu] if self.gpu is not None else None
            self.model = DistributedDataParallel(self.model, device_ids=device_ids)

        self.optimizer = optimizers[optimizer](
            self.model.parameters(), lr=learning_rate, momentum=momentum)
        self.loss_func = losses[loss]()

    def _prepare_target(self, batch_target, batch_output):
        """Move targets to the device and, for BCE, match the logits' dtype/rank."""
        batch_target = batch_target.to(self.device)
        if self.loss == 'BCE':
            # BCEWithLogitsLoss requires float targets shaped like the logits;
            # a 1-D label vector gets a trailing singleton dim only when the
            # logits are 2-D (e.g. (N, 1)).
            batch_target = batch_target.float()
            if batch_target.dim() == 1 and batch_output.dim() == 2:
                batch_target = batch_target.unsqueeze(1)
        return batch_target

    def _summarize(self, label, n_batches, n_samples, sum_loss, sum_correct):
        """Log and return (mean batch loss, accuracy); both NaN for an empty pass."""
        mean_loss = sum_loss / n_batches if n_batches else float('nan')
        acc = sum_correct / n_samples if n_samples else float('nan')
        self.logger.debug('{:>14} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'
                          .format(label, n_samples, mean_loss, 100*acc))
        return mean_loss, acc

    def train_epoch(self, data_loader):
        """Train for one pass over ``data_loader``.

        Args:
            data_loader: iterable of ``(batch_input, batch_target)`` pairs
                that also supports ``len()`` (used in the per-batch log).

        Returns:
            dict with ``train_loss``, the mean of the per-batch losses
            (NaN if the loader yields no batches).
        """
        self.model.train()
        sum_loss = 0
        sum_correct = 0
        n_batches = 0
        # Count samples actually seen: len(sampler) over-counts with drop_last
        n_samples = 0
        # Loop over training batches
        for i, (batch_input, batch_target) in enumerate(data_loader):
            batch_input = batch_input.to(self.device)
            self.model.zero_grad()
            batch_output = self.model(batch_input)
            batch_target = self._prepare_target(batch_target, batch_output)
            batch_loss = self.loss_func(batch_output, batch_target)
            batch_loss.backward()
            self.optimizer.step()
            loss = batch_loss.item()
            batch_size = len(batch_input)
            # detach: the accuracy computation need not be tracked by autograd
            n_correct = self.accuracy(batch_output.detach(), batch_target)
            sum_loss += loss
            sum_correct += n_correct
            n_batches += 1
            n_samples += batch_size
            self.logger.debug(' batch {:>3}/{:<3} | {:6,} samples | Loss {:.5f} | Accuracy {:6.2f}'
                              .format(i+1, len(data_loader), batch_size, loss, 100*n_correct/batch_size))
        train_loss, _ = self._summarize('Training', n_batches, n_samples,
                                        sum_loss, sum_correct)
        return dict(train_loss=train_loss)

    @torch.no_grad()
    def evaluate(self, data_loader, mode):
        """Evaluate the model on ``data_loader`` without tracking gradients.

        Args:
            data_loader: iterable of ``(batch_input, batch_target)`` pairs.
            mode: label used in the summary log line.

        Returns:
            dict with ``valid_loss`` (mean of per-batch losses) and
            ``valid_acc`` (fraction of samples predicted correctly); both are
            NaN if the loader yields no batches.
        """
        self.model.eval()
        sum_loss = 0
        sum_correct = 0
        n_batches = 0
        n_samples = 0
        # Loop over batches
        for batch_input, batch_target in data_loader:
            batch_input = batch_input.to(self.device)
            batch_output = self.model(batch_input)
            batch_target = self._prepare_target(batch_target, batch_output)
            sum_loss += self.loss_func(batch_output, batch_target).item()
            sum_correct += self.accuracy(batch_output, batch_target)
            n_batches += 1
            n_samples += len(batch_input)
        valid_loss, valid_acc = self._summarize(mode, n_batches, n_samples,
                                                sum_loss, sum_correct)
        return dict(valid_loss=valid_loss, valid_acc=valid_acc)

    def accuracy(self, batch_output, batch_target):
        """Return the number of correct predictions in a batch as a Python int.

        For 'BCE', logits are thresholded at sigmoid > 0.5. With multi-dim
        outputs, a sample counts as correct only if every entry along dim 1
        matches. Otherwise the argmax over dim 1 is compared with the class
        targets.
        """
        if self.loss == 'BCE':
            batch_preds = (torch.sigmoid(batch_output) > 0.5).float()
            matches = batch_preds.eq(batch_target)
            if batch_preds.dim() != 1:
                matches = matches.all(dim=1)
            # .item() keeps the return type consistent with the CE branch
            return matches.sum().item()
        batch_preds = batch_output.argmax(dim=1)
        return batch_preds.eq(batch_target).sum().item()
    
def get_trainer(**kwargs):
    """Construct and return a ``GenericTrainer``.

    All keyword arguments are forwarded unchanged to ``GenericTrainer``,
    which in turn passes them to ``BaseTrainer.__init__``.
    """
    return GenericTrainer(**kwargs)

def _test():
    """Smoke test: create a trainer with output_dir='./' and build its default model."""
    trainer = GenericTrainer(output_dir='./')
    trainer.build_model()
    
