"""
Model file
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class Net(nn.Module):
    """
    Small convolutional classifier.

    Architecture: conv(3x3) -> ReLU -> conv(3x3) -> ReLU -> 2x2 max-pool ->
    dropout(0.25) -> flatten -> linear(128) -> ReLU -> dropout(0.5) ->
    linear(num_classes) -> log_softmax.

    With the default arguments (single-channel 28x28 images, 10 classes) the
    layer shapes are identical to the original hard-coded model, so existing
    ``Net()`` callers and saved state dicts remain compatible.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output classes (width of the final layer).
        image_size: Height and width of the (square) input images. It is only
            used to size ``fc1``. Feeding images of a different size to
            ``forward`` will fail at ``fc1``.

    Raises:
        ValueError: If ``image_size`` is too small to leave at least one pixel
            after the two unpadded 3x3 convolutions and the 2x2 max-pool.
    """

    def __init__(
        self,
        *,
        in_channels: int = 1,
        num_classes: int = 10,
        image_size: int = 28,
    ) -> None:
        super().__init__()
        # Each unpadded, stride-1 3x3 convolution shrinks H and W by 2 (two
        # convs -> -4); max_pool2d(2) then halves the result (floor).
        pooled_size = (image_size - 4) // 2
        if pooled_size < 1:
            raise ValueError(
                f"image_size must be at least 6, got {image_size}"
            )
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 64 feature maps of pooled_size x pooled_size each; for the default
        # 28x28 input this is 64 * 12 * 12 = 9216, the original constant.
        self.fc1 = nn.Linear(64 * pooled_size * pooled_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x_input: Tensor) -> Tensor:
        """
        Run the network on a batch of images.

        Args:
            x_input: Image batch laid out as ``(N, in_channels, image_size,
                image_size)``. ``torch.flatten(..., 1)`` keeps dim 0 as the
                batch dimension.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding log-probabilities
            (``log_softmax`` over dim 1).

        Dropout layers are active only in training mode, so call ``.eval()``
        for deterministic inference.
        """
        output = self.conv1(x_input)
        output = F.relu(output)
        output = self.conv2(output)
        output = F.relu(output)
        output = F.max_pool2d(output, 2)
        output = self.dropout1(output)
        # pylint: disable=E1101
        output = torch.flatten(output, 1)
        # pylint: enable=E1101
        output = self.fc1(output)
        output = F.relu(output)
        output = self.dropout2(output)
        output = self.fc2(output)
        output = F.log_softmax(output, dim=1)
        return output
