"""
https://github.com/jacobkimmel/pytorch_modelsize
"""
import torch
from torch.autograd import Variable
import numpy as np


class SizeEstimator(object):
    """Estimate how much memory a PyTorch model needs for a given input size.

    The estimate is the sum of three parts, all measured in bits:

    * parameters: every element of every model parameter;
    * activations: every element of every leaf-module output, counted twice
      (once for the forward pass and once for the backward pass);
    * input: every element of the sample input.

    Every element is assumed to occupy ``bits`` bits.
    """

    def __init__(self, model, input_size=(1, 1, 32, 32), bits=32):
        """
            Estimates the size of PyTorch models in memory
            for a given input size
        Args:
            model: the ``torch.nn.Module`` to measure.
            input_size: shape of the sample input tensor that is fed through
                ``model`` to discover activation sizes.
            bits: number of bits assumed per tensor element (32 for float32).
        """
        self.model = model
        self.input_size = input_size
        self.bits = bits
        # The attributes below are filled in by estimate_size() and its helpers.
        self.param_sizes = None
        self.out_sizes = None
        self.param_bits = None
        self.forward_backward_bits = None
        self.input_bits = None

    def get_parameter_sizes(self):
        """
        Get sizes of all parameters in model.

        ``Module.parameters()`` already recurses into submodules and yields
        each parameter once. Nested containers are therefore not
        double-counted, and parameters registered directly on the root module
        are included.
        """
        self.param_sizes = [np.array(p.size()) for p in self.model.parameters()]

    def get_output_sizes(self):
        """
        Run a sample input through the model and record each leaf output size.

        A forward hook is attached to every leaf module (one with no child
        modules). The model's own ``forward`` is then called once, so nested
        containers and non-sequential forward logic are handled correctly.
        Tensor outputs are recorded in call order. For tuple/list outputs,
        every tensor element is recorded.

        NOTE: the model runs in whatever train/eval mode it is currently in.
        Modules that keep running statistics (e.g. BatchNorm in train mode)
        may update them.
        """
        out_sizes = []

        def _record(_module, _inputs, output):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for item in outputs:
                if torch.is_tensor(item):
                    out_sizes.append(np.array(item.size()))

        leaves = [
            m for m in self.model.modules()
            if next(iter(m.children()), None) is None
        ]
        handles = [m.register_forward_hook(_record) for m in leaves]
        try:
            # Zeros instead of an uninitialized tensor: garbage memory may hold
            # NaN/inf, which could propagate into stateful layers.
            sample = torch.zeros(*self.input_size)
            with torch.no_grad():
                self.model(sample)
        finally:
            # Always detach the hooks so the model is left unmodified.
            for handle in handles:
                handle.remove()

        self.out_sizes = out_sizes

    @staticmethod
    def _count_elements(sizes):
        """Return the total number of elements across a list of shapes."""
        return sum(np.prod(np.array(s)) for s in sizes)

    def calc_param_bits(self):
        """
        Calculate total number of bits to store `model` parameters
        """
        self.param_bits = self._count_elements(self.param_sizes) * self.bits

    def calc_forward_backward_bits(self):
        """
        Calculate bits to store forward and backward pass
        """
        total_bits = self._count_elements(self.out_sizes) * self.bits
        # multiply by 2 for both forward AND backward
        self.forward_backward_bits = total_bits * 2

    def calc_input_bits(self):
        """
        Calculate bits to store input
        """
        self.input_bits = np.prod(np.array(self.input_size)) * self.bits

    def estimate_size(self):
        """
        Estimate model size in memory in megabytes and bits.

        Populates ``param_sizes``, ``out_sizes``, ``param_bits``,
        ``forward_backward_bits`` and ``input_bits`` as a side effect.

        Returns:
            tuple: (size in megabytes, size in bits)
        """
        self.get_parameter_sizes()
        self.get_output_sizes()
        self.calc_param_bits()
        self.calc_forward_backward_bits()
        self.calc_input_bits()
        total = self.param_bits + self.forward_backward_bits + self.input_bits

        total_megabytes = (total / 8) / (1024 ** 2)
        return total_megabytes, total
