"""
Train the ``Net`` model on the MNIST training set and save its weights.
"""
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms

from [=NAME=].model import Net


def train(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    optimizer: optim.Optimizer,
    epoch: int,
    log_interval: int = 10,
) -> None:
    """
    Run one epoch of training over ``train_loader``.

    Puts ``model`` in training mode. For every batch, it computes the
    negative log-likelihood loss of ``model(data)`` against ``target``,
    backpropagates, and takes one optimizer step.

    Args:
        model: Network to train. ``F.nll_loss`` is applied directly to its
            output, so the model is expected to return log-probabilities
            (e.g. end in ``log_softmax``). TODO confirm against ``Net``.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number; used only in the progress messages.
        log_interval: Print a progress line every ``log_interval`` batches,
            starting with the first batch. Defaults to 10.

    Raises:
        ValueError: If ``log_interval`` is less than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Approximate count of samples consumed before this batch.
            seen = batch_idx * len(data)
            total = len(train_loader.dataset)
            percent = 100.0 * batch_idx / len(train_loader)
            print(
                f"Train Epoch: {epoch} [{seen}/{total} ({percent:.0f}%)]"
                f"\tLoss: {loss.item():.6f}"
            )


def main(
    *,
    epochs: int = 1,
    learning_rate: float = 0.001,
    gamma: float = 0.1,
    batch_size: int = 64,
    data_dir: str = "data",
    model_path: str = "results/models/mnist_cnn.pt",
) -> None:
    """
    Train ``Net`` on the MNIST training split and save its ``state_dict``.

    All arguments are keyword-only, so the script entry point ``main()``
    keeps working unchanged.

    Args:
        epochs: Number of passes over the training set.
        learning_rate: Initial learning rate for ``Adadelta``.
        gamma: Factor the learning rate is multiplied by after each epoch.
        batch_size: Number of samples per training batch.
        data_dir: Root directory where MNIST is stored or downloaded.
        model_path: Destination file for the trained weights. Missing
            parent directories are created.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    # download=True fetches MNIST on first use instead of failing with
    # "Dataset not found"; it does nothing if the files are already present.
    train_set = datasets.MNIST(
        data_dir, train=True, download=True, transform=transform
    )
    # The DataLoader defaults (batch_size=1, shuffle=False) would train on one
    # sample per step in a fixed order; batch and reshuffle every epoch instead.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True
    )

    model = Net()
    # NOTE(review): PyTorch's Adadelta default (and the conventional value) is
    # lr=1.0; 0.001 makes updates very small. Confirm this value is intended.
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)
    # step_size=1: decay the learning rate by ``gamma`` after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    for epoch in range(1, epochs + 1):
        train(model, train_loader, optimizer, epoch)
        scheduler.step()

    # torch.save does not create missing directories, so make them first.
    output = Path(model_path)
    output.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output)


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
