from torch.nn import (
    AvgPool2d,
    Conv2d,
    CrossEntropyLoss,
    MSELoss,
    Dropout,
    Linear,
    MaxPool2d,
    ReLU,
    Sigmoid,
    Tanh,
    ZeroPad2d,
)

from .avgpool2d import AvgPool2DDerivatives
from .conv2d import Conv2DDerivatives
from .crossentropyloss import CrossEntropyLossDerivatives
from .mseloss import MSELossDerivatives
from .dropout import DropoutDerivatives
from .linear import LinearDerivatives
from .maxpool2d import MaxPool2DDerivatives
from .relu import ReLUDerivatives
from .sigmoid import SigmoidDerivatives
from .tanh import TanhDerivatives
from .zeropad2d import ZeroPad2dDerivatives

# Lookup table from a ``torch.nn`` module class to the class that implements
# its derivatives. It is built from an ordered sequence of (module, derivatives)
# pairs; ``dict`` keeps insertion order, so iteration order matches the listing.
derivatives_for = dict(
    (
        # Linear and convolution layers
        (Linear, LinearDerivatives),
        (Conv2d, Conv2DDerivatives),
        # Pooling and padding
        (AvgPool2d, AvgPool2DDerivatives),
        (MaxPool2d, MaxPool2DDerivatives),
        (ZeroPad2d, ZeroPad2dDerivatives),
        # Dropout
        (Dropout, DropoutDerivatives),
        # Element-wise activations
        (ReLU, ReLUDerivatives),
        (Tanh, TanhDerivatives),
        (Sigmoid, SigmoidDerivatives),
        # Loss functions
        (CrossEntropyLoss, CrossEntropyLossDerivatives),
        (MSELoss, MSELossDerivatives),
    )
)
