# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/capsule.ipynb (unless otherwise specified).

__all__ = ['to', 'train', 'validate', 'predict', 'Capsule']

# Cell
from ..loader import *
from ..torch_loader import *
from ..paths import loaddill, dumpdill

def to(item, device):
    """Recursively move ``item`` (and everything nested inside it) to ``device``.

    * ``None`` is returned unchanged.
    * ``torch.Tensor`` / ``nn.Module`` objects are moved with their own
      ``.to(device)``.
    * ``dict`` values are moved recursively (keys are kept).
    * ``list`` and ``tuple`` elements are moved recursively; the result is
      always a ``list``.

    Raises:
        NotImplementedError: for any other type.
    """
    if item is None:
        return item
    if isinstance(item, (torch.Tensor, nn.Module)):
        return item.to(device)
    if isinstance(item, dict):
        return {key: to(value, device) for key, value in item.items()}
    if not isinstance(item, (list, tuple)):
        raise NotImplementedError(f"function is not implemented for {type(item)}")
    return [to(element, device) for element in item]


def train(train_function):
    """Turn ``train_function(self, data) -> dict`` into a full training step.

    The returned method runs, in order:

    1. ``self.before_train_batch(data)`` -- on ``Capsule`` this switches to
       train mode, zeroes the optimizer's gradients and moves ``data`` to
       ``self.device``;
    2. ``train_function(self, data)``;
    3. ``self.after_train_batch(outputs)`` -- on ``Capsule`` this calls
       ``outputs["loss"].backward()`` and ``self.optimizer.step()``.

    The wrapper keeps the decorated function's name and docstring.

    Returns:
        The outputs dict produced by the hooks.
    """
    import functools

    @functools.wraps(train_function)
    def _train_batch(self, data):
        data = self.before_train_batch(data)
        outputs = train_function(self, data)
        outputs = self.after_train_batch(outputs)
        # Callers such as Capsule.fit unpack the outputs with ``**``, so a
        # mapping is required.
        assert isinstance(outputs, dict)
        return outputs

    return _train_batch


def validate(validation_function):
    """Turn ``validation_function(self, data) -> dict`` into a validation step.

    The returned method runs ``self.before_validate_batch(data)`` (on
    ``Capsule``: eval mode + move to ``self.device``), then
    ``validation_function``, then ``self.after_validate_batch(outputs)``.
    The whole step runs under ``torch.no_grad()``, so no autograd graph is
    recorded.

    Note: ``torch.no_grad()`` must decorate the *returned* wrapper. When it
    decorates this factory instead, gradients are disabled only while the
    class body is being defined, not when batches are actually validated.

    Returns:
        The outputs dict produced by the hooks.
    """
    import functools

    @functools.wraps(validation_function)
    @torch.no_grad()
    def _validate_batch(self, data):
        data = self.before_validate_batch(data)
        outputs = validation_function(self, data)
        outputs = self.after_validate_batch(outputs)
        # Callers such as Capsule.evaluate unpack the outputs with ``**``.
        assert isinstance(outputs, dict)
        return outputs

    return _validate_batch


def predict(predict_function):
    """Turn ``predict_function(self, data)`` into an inference step.

    The returned method runs ``self.before_predict(data)`` (on ``Capsule``:
    eval mode + move to ``self.device``), then ``predict_function``, then
    ``self.after_predict(outputs)``. This mirrors the before/after hooks
    used by ``train`` and ``validate``. The whole step runs under
    ``torch.no_grad()``.

    Note: ``torch.no_grad()`` has to decorate the returned wrapper. When it
    decorates this factory instead, inference would run with autograd
    enabled.

    Returns:
        Whatever ``self.after_predict`` returns (on ``Capsule`` this is
        ``outputs`` unchanged).
    """
    import functools

    @functools.wraps(predict_function)
    @torch.no_grad()
    def _predict(self, data):
        data = self.before_predict(data)
        outputs = predict_function(self, data)
        return self.after_predict(outputs)

    return _predict


class Capsule(nn.Module):
    """``nn.Module`` base class that bundles a simple train/validate/predict loop.

    Subclasses implement ``forward`` and set ``self.optimizer`` and
    ``self.criterion``; the default hooks and batch methods below use them.
    The default ``train_batch`` / ``validate_batch`` / ``predict_batch``
    unpack each batch as an ``(x, y)`` pair.

    Metrics live in ``self.report`` (a ``Report``). It is created by
    :meth:`fit`, or restored from a dill file via ``__init__(report=...)``
    or :meth:`load`.
    """

    def __init__(self, report=None):
        """Initialise the module.

        Args:
            report: optional path to a dill-serialised report to restore
                (for example the ``<weights>.report`` file written by
                :meth:`save`).
        """
        super().__init__()
        if report is not None:
            self.report = loaddill(report)

    # Train Utils
    def before_train_batch(self, data):
        """Switch to train mode, zero the optimizer's gradients and move ``data`` to the device."""
        self.train()
        self.optimizer.zero_grad()
        # ``self.device`` is set by fit()/evaluate(); fall back to "cuda" otherwise.
        data = to(data, getattr(self, "device", "cuda"))
        return data

    def after_train_batch(self, outputs):
        """Backpropagate ``outputs["loss"]`` and take one optimizer step."""
        outputs["loss"].backward()
        self.optimizer.step()
        return outputs

    # Validation Utils
    def before_validate_batch(self, data):
        """Switch to eval mode and move ``data`` to the device."""
        self.eval()
        data = to(data, getattr(self, "device", "cuda"))
        return data

    def after_validate_batch(self, outputs):
        """Post-process validation outputs; returns them unchanged by default."""
        return outputs

    def before_predict(self, data):
        """Switch to eval mode and move ``data`` to the device."""
        self.eval()
        data = to(data, getattr(self, "device", "cuda"))
        return data

    def after_predict(self, outputs):
        """Post-process predictions; returns them unchanged by default."""
        return outputs

    def load(self, weights_path=None, device="cpu"):
        """Load weights from ``weights_path`` and, if present, its report.

        Args:
            weights_path: weights file passed to ``load_torch_model_weights_to``.
                When falsy, nothing is loaded.
            device: forwarded to ``load_torch_model_weights_to``.

        The report is read from ``f"{weights_path}.report"``, which is the
        layout written by :meth:`save`. A missing report file is ignored.
        Any other error while reading the report is emitted as a warning
        instead of being raised, so the weights stay loaded.
        """
        if not weights_path:
            return
        load_torch_model_weights_to(self, weights_path, device=device)
        # f-string (not ``+``) so pathlib.Path arguments work too.
        report_path = f"{weights_path}.report"
        try:
            self.report = loaddill(report_path)
        except FileNotFoundError:
            pass  # weights were saved without a report
        except Exception as err:  # best-effort: a bad report must not undo a good weights load
            import warnings

            warnings.warn(f"could not load report from {report_path}: {err!r}")

    def save(self, save_to):
        """Save weights to ``save_to`` and the report to ``f"{save_to}.report"``.

        The report file is skipped when no report exists yet (e.g. the model
        was never fit). Without this check, saving would fail with an
        AttributeError after the weights had already been written.
        """
        save_torch_model_weights_from(self, save_to)
        report = getattr(self, "report", None)
        if report is not None:
            dumpdill(report, f"{save_to}.report")

    # Fit function
    def fit(self, trn_dl=None, val_dl=None, num_epochs=1, device="cuda", save_to=None):
        """Train for ``num_epochs`` epochs, validating after each epoch.

        Args:
            trn_dl: training batches (must support ``len``); skipped if None.
            val_dl: validation batches (must support ``len``); skipped if None.
            num_epochs: number of epochs to run.
            device: device the module (and every batch) is moved to.
            save_to: if truthy, weights and report are saved there at the end.

        Pressing Ctrl-C (KeyboardInterrupt) stops training early. The report
        is still plotted and, if ``save_to`` is set, the model is still saved.
        """
        if hasattr(self, "report"):
            # The existing report is handed to the new one via ``old_report``
            # (presumably to carry its history forward -- see Report).
            self.report = Report(num_epochs, old_report=self.report)
        else:
            self.report = Report(num_epochs)

        self.device = device
        to(self, self.device)

        try:
            for epoch in range(num_epochs):
                self.report.n_epochs = num_epochs
                if trn_dl is not None:
                    n_batches = len(trn_dl)
                    for ix, data in enumerate(trn_dl):
                        outputs = self.train_batch(data)
                        # pos is the fractional epoch reached after this batch.
                        self.report.record(pos=(epoch + (ix + 1) / n_batches), **outputs, end="\r")
                if val_dl is not None:
                    self.evaluate(val_dl, report=self.report, device=device, epoch=epoch)
                self.report.report_avgs(epoch + 1)
        except KeyboardInterrupt:
            pass  # deliberate: let the user stop training and keep the results

        self.report.plot(log=True, smooth=10)
        if save_to:
            self.save(save_to)

    def evaluate(self, val_dl, report=None, device="cuda", epoch=None):
        """Run ``validate_batch`` on every batch of ``val_dl`` and record the outputs.

        Args:
            val_dl: validation batches (must support ``len``).
            report: report to record into. When None, a fresh single-epoch
                ``Report`` is created, positions start at epoch 0, and its
                averages and plot are shown at the end.
            device: device the module (and every batch) is moved to.
            epoch: epoch offset for recorded positions when ``report`` is
                given; defaults to 0 when None.
        """
        show_report = report is None
        if show_report:
            report = Report(1)
        if show_report or epoch is None:
            epoch = 0
        self.device = device
        to(self, self.device)

        n_batches = len(val_dl)
        for ix, data in enumerate(val_dl):
            outputs = self.validate_batch(data)
            report.record(pos=(epoch + (ix + 1) / n_batches), **outputs, end="\r")

        if show_report:
            report.report_avgs(1)
            report.plot(log=True, smooth=3)

    @train
    def train_batch(self, data):
        """Default training step: ``{"loss": criterion(self(x), y)}`` for an ``(x, y)`` batch."""
        x, y = data
        _y = self(x)
        loss = self.criterion(_y, y)
        return {"loss": loss}

    @validate
    def validate_batch(self, data):
        """Default validation step: ``{"loss": criterion(self(x), y)}`` for an ``(x, y)`` batch."""
        x, y = data
        output = self(x)
        loss = self.criterion(output, y)
        return {"loss": loss}

    @predict
    def predict_batch(self, data):
        """Default prediction: ``self(x)`` for an ``(x, _)`` batch; the second item is ignored."""
        x, _ = data
        outputs = self(x)
        return outputs

