import torch
from torch import nn


class DilatedBlock(nn.Module):
    """Parallel bank of dilated 2-D convolutions whose outputs are concatenated.

    One ``nn.Conv2d`` branch is built per dilation rate in ``config.layers``.
    Every branch receives the same input, and the branch outputs are
    concatenated along dim 1. The result therefore has
    ``config.feature_size * len(config.layers)`` channels.

    Args:
        config: object exposing ``layers`` (iterable of dilation rates, each an
            int or a pair of ints) and ``feature_size`` (output channels per
            branch).
        input_size: number of input channels fed to every branch.
        kernel_size: convolution kernel size, an int or a pair (default 3).
            Padding is ``dilation * (kernel_size - 1) // 2`` per dimension.
            For odd kernel sizes this keeps every branch at the input's
            spatial size, so the branch outputs can be concatenated. For
            ``kernel_size=3`` it equals ``dilation``.
    """

    def __init__(self, config, input_size, kernel_size=3):
        super().__init__()
        self.config = config
        self.layers = self.config.layers
        # One conv per dilation rate; ModuleList registers them as submodules
        # (state_dict keys: encoder_layers.<i>.weight / .bias).
        self.encoder_layers = nn.ModuleList(
            nn.Conv2d(
                input_size,
                self.config.feature_size,
                kernel_size=kernel_size,
                bias=True,
                padding=self._same_padding(kernel_size, dilation),
                dilation=dilation,
            )
            for dilation in self.layers
        )

    @staticmethod
    def _same_padding(kernel_size, dilation):
        """Return per-dimension padding keeping a stride-1 dilated conv size-preserving."""
        # A dilated kernel spans dilation * (k - 1) + 1 pixels. Padding half of
        # the extra extent on each side keeps H/W unchanged when k is odd.
        # A hard-coded ``padding=dilation`` is only correct for k == 3.
        kernel = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        dil = (dilation, dilation) if isinstance(dilation, int) else tuple(dilation)
        return tuple(d * (k - 1) // 2 for k, d in zip(kernel, dil))

    def forward(self, inputs):
        """Apply every dilated branch to ``inputs`` and concatenate on dim 1.

        Assumes a batched ``(N, C, H, W)`` tensor (TODO confirm with callers).
        Under that layout, dim 1 is the channel axis.
        """
        return torch.cat([layer(inputs) for layer in self.encoder_layers], dim=1)
