"""
Validate file
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torchvision import datasets, transforms

from [=NAME=].model import Net


def validate(model: nn.Module, val_loader: torch.utils.data.DataLoader) -> None:
    """
    Evaluate ``model`` on every batch of ``val_loader`` and print the results.

    The model is switched to evaluation mode and run without gradient
    tracking. Its outputs are fed to ``F.nll_loss``, so they are expected to be
    log-probabilities over classes (e.g. the output of ``log_softmax``); the
    predicted class is the arg-max over dimension 1.

    Prints the average negative log-likelihood per evaluated sample and the
    accuracy. Averages use the number of samples actually yielded by the
    loader, so they stay correct when a sampler/subset or ``drop_last`` is used.
    If the loader yields no samples, a short notice is printed instead.

    Args:
        model: Network to evaluate. It is left in eval mode on return.
        val_loader: Loader yielding ``(data, target)`` batches.
    """
    model.eval()
    test_loss = 0.0
    correct = 0  # int, so the report reads "9800/10000" rather than "9800.0/10000"
    total = 0  # samples actually evaluated (may differ from len(dataset))
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            # Sum (not mean) per batch so dividing by ``total`` afterwards gives
            # the per-sample average regardless of batch sizes.
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # Index of the max log-probability is the predicted class.
            pred = output.argmax(dim=1, keepdim=True)
            correct += int(pred.eq(target.view_as(pred)).sum().item())
            total += target.size(0)

    if total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print("\nTest set: no samples to evaluate\n")
        return

    test_loss /= total

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            total,
            100.0 * correct / total,
        )
    )


def main() -> None:
    """
    Load the trained CNN weights and report its performance on the MNIST test split.

    Weights are read from ``results/models/mnist_cnn.pt`` into a fresh ``Net``;
    the MNIST test split is read from the local ``data`` directory (no download
    is requested here) and evaluated with ``validate``.
    """
    model = Net()
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; the model itself is constructed on the CPU.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    state_dict = torch.load("results/models/mnist_cnn.pt", map_location="cpu")
    model.load_state_dict(state_dict)

    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean/std used at
    # training time — must match the training transform; confirm.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    validate_set = datasets.MNIST("data", train=False, transform=transform)
    # Batch the evaluation: the DataLoader default (batch_size=1) would run the
    # network once per image, which is needlessly slow and does not change the
    # accuracy.
    val_loader = torch.utils.data.DataLoader(validate_set, batch_size=1000)

    validate(model, val_loader)


if __name__ == "__main__":
    # Run the validation only when executed as a script, not when imported.
    main()
