import torch
from collections import Counter

# Class names (as returned by ``nn.Module._get_name()``) of the layer types
# that ``parameter_summary`` and ``get_num_layers`` report on: a module is
# included only if its class name appears, as a string, in this list.
# NOTE(review): these names do not appear to be defined in this module, so
# ``from <this module> import *`` would raise AttributeError. Consider moving
# the list to a private constant (e.g. ``_LAYER_NAMES``) and exporting only
# ``parameter_summary`` / ``get_num_layers``.
__all__ = [
    'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
    'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
    'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
    'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
    'Tanhshrink', 'RReLU', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
    'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
    'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
    'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
    'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
    'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
    'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
    'LSTMCell', 'GRUCell', 'PixelShuffle', 'PixelUnshuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
    'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
    'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
    'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
    'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
    'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
    'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
    'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d',
    'LazyInstanceNorm1d', 'LazyInstanceNorm2d', 'LazyInstanceNorm3d',
    'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle'
]

def parameter_summary(model, _print: bool = True, border: bool = False,
                      layer_names=None) -> tuple:
    '''
    Build, and optionally print, a per-layer table of parameter counts.

    Every module yielded by ``model.modules()`` whose class name
    (``_get_name()``) is in ``layer_names`` gets one row. Only the
    parameters a module owns directly are counted (``recurse=False``), so
    a parameter is never counted twice through a parent container.
    Parameters registered as ``None`` (e.g. weight/bias of norm layers
    built with ``affine=False``) are skipped. Layers whose parameters are
    not called ``weight``/``bias`` (e.g. ``LSTM``, ``MultiheadAttention``)
    are counted as well.

    Args:
        model: PyTorch model (anything exposing ``.modules()``).
        _print: default True. If False, nothing is printed and only the
            counts are returned.
        border: default False. If True, a separator line is written before
            each layer row.
        layer_names: optional iterable of class names to report on. When
            None (the default), the module-level ``__all__`` list is used.

    Returns:
        A tuple
        (total-TRAINABLE-params, total-params, total-NON-trainable-params)
    '''
    allowed = set(__all__ if layer_names is None else layer_names)
    separator = "_" * 100 + "\n"
    row = " {:<20}   {:^20}\t{:,}  {:>25} {:^30}\n"
    parts = [
        "{:<20}     {:^20} {}  {:>20} {:^30}\n".format(
            'LAYER TYPE', 'KERNEL SHAPE', '#parameters', ' (weights+bias)', 'requires_grad'),
        separator,
    ]
    total_params = 0
    non_trainable = 0
    index = 1
    for module in model.modules():
        layer = module._get_name()
        if layer not in allowed:
            continue
        if border:
            parts.append(separator)
        label = layer + '-' + str(index)
        index += 1

        # Direct parameters only; None placeholders are not yielded, which
        # avoids the KeyError/AttributeError on bias-less or affine=False layers.
        own = dict(module.named_parameters(recurse=False))
        if not own:
            parts.append(" {:<20}   {:^20}\t{:}  {:>25} {:^30}\n".format(label, '-', '-', '-', ''))
            continue

        params = list(own.values())
        counts = [p.numel() for p in params]
        layer_total = sum(counts)
        total_params += layer_total
        non_trainable += sum(c for p, c in zip(params, counts) if not p.requires_grad)

        weight = own.get('weight')
        if weight is not None:
            shape = str(list(weight.shape))
        else:
            # No single 'weight' (e.g. LSTM): show every parameter's shape.
            shape = str([list(p.shape) for p in params])

        if len(params) == 1 and weight is not None:
            # Weight-only layer (no bias): keep the historical "(x+0)" breakdown.
            breakdown = f'({counts[0]}+0)'
        else:
            breakdown = '(' + ' + '.join(str(c) for c in counts) + ')'
        grads = ' '.join(str(p.requires_grad) for p in params)
        parts.append(row.format(label, shape, layer_total, breakdown, grads))
    parts.append("=" * 100 + "\n")
    s = ''.join(parts)

    if _print:
        print(s)
        print('Total parameters {:,}'.format(total_params))
        print('Total Non-Trainable parameters {:,}'.format(non_trainable))
        print('Total Trainable parameters {:,}'.format(total_params - non_trainable))
    return (total_params - non_trainable, total_params, non_trainable)


def get_num_layers(model, layer_names=None) -> dict:
    '''
    Count how many modules of each recognised layer type ``model`` contains.

    Args:
        model: PyTorch model (anything exposing ``.modules()``).
        layer_names: optional iterable of class names to count. When None
            (the default), the module-level ``__all__`` list is used.

    Returns:
        A plain dict mapping class name (``_get_name()``) to its number of
        occurrences in ``model.modules()``, keyed in order of first
        appearance. Types that never occur are absent.
    '''
    allowed = set(__all__ if layer_names is None else layer_names)
    names = (module._get_name() for module in model.modules())
    return dict(Counter(name for name in names if name in allowed))