import numpy as np

from tqdm import tqdm
from .abc_interpreter import InputGradientInterpreter
from ..data_processor.readers import preprocess_inputs, preprocess_save_path
from ..data_processor.visualizer import explanation_to_vis, show_vis_explanation, save_image


class SmoothGradInterpreter(InputGradientInterpreter):
    """
    Smooth Gradients Interpreter.

    The Smooth Gradients method reduces meaningless local variations in partial
    derivatives by adding random noise to the inputs multiple times and taking
    the average of the resulting gradients.

    More details regarding the Smooth Gradients method can be found in the original paper:
    http://arxiv.org/pdf/1706.03825.pdf
    """

    def __init__(self,
                 paddle_model,
                 use_cuda=None,
                 device='gpu:0',
                 model_input_shape=None):
        """
        Initialize the SmoothGradInterpreter.

        Args:
            paddle_model (callable): A paddle model that outputs predictions.
            use_cuda (bool, optional): Whether or not to use cuda. Default: None
            device (str, optional): The device identifier forwarded to
                ``InputGradientInterpreter.__init__``. Default: 'gpu:0'
            model_input_shape (list, optional): The input shape of the model.
                Default: [3, 224, 224]
        """
        InputGradientInterpreter.__init__(self, paddle_model, device, use_cuda)

        # A list literal as the default would be a single object shared by every
        # instance (and stored on self), so use a None sentinel instead.
        if model_input_shape is None:
            model_input_shape = [3, 224, 224]
        self.model_input_shape = model_input_shape
        self.data_type = 'float32'

    def interpret(self,
                  inputs,
                  labels=None,
                  noise_amount=0.1,
                  n_samples=50,
                  visual=True,
                  save_path=None):
        """
        Main function of the interpreter.

        Args:
            inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy array of read images.
            labels (list or tuple or numpy.ndarray, optional): The target labels to analyze. The number of labels should be equal to the number of images. If None, the most likely label for each image will be used. Default: None
            noise_amount (float, optional): Noise level of added noise to the image.
                                            The std of Gaussian random noise is noise_amount * (x_max - x_min),
                                            computed separately for each image. Default: 0.1
            n_samples (int, optional): The number of new images generated by adding noise. Must be >= 1. Default: 50
            visual (bool, optional): Whether or not to visualize the processed image. Default: True
            save_path (str or list of strs or None, optional): The filepath(s) to save the processed image(s). If None, the image will not be saved. Default: None

        Raises:
            ValueError: If ``n_samples`` is smaller than 1 (averaging would divide by zero).

        :return: interpretations/gradients for each image, averaged over the noisy
                 samples; same shape as the preprocessed input batch.
        :rtype: numpy.ndarray
        """
        if n_samples < 1:
            raise ValueError(
                "n_samples must be a positive integer, got {}".format(n_samples))

        imgs, data = preprocess_inputs(inputs, self.model_input_shape)

        bsz = len(data)
        self.data_type = np.asarray(data).dtype

        self._build_predict_fn(gradient_of='probability')

        # Obtain the labels: default to the model's predicted label for each image.
        if labels is None:
            _, labels = self.predict_fn(data, None)
        labels = np.array(labels).reshape((bsz, ))

        avg_gradients = self._smooth_gradients(data, labels, noise_amount, n_samples)

        # Visualize and/or save only when the caller asked for it.
        if visual or save_path is not None:
            self._visualize(imgs, avg_gradients, bsz, visual, save_path)

        # intermediate results, for possible further usages.
        self.labels = labels

        return avg_gradients

    def _smooth_gradients(self, data, labels, noise_amount, n_samples):
        """Return the mean gradient over ``n_samples`` Gaussian-noised copies of ``data``."""
        # Per-image std: noise_amount * (max - min) over every non-batch axis.
        reduce_axes = tuple(range(1, data.ndim))
        stds = noise_amount * (
            np.max(data, axis=reduce_axes) - np.min(data, axis=reduce_axes))
        # Shape (bsz, 1, ..., 1) so each image's std broadcasts over its own pixels.
        stds = stds.reshape((-1, ) + (1, ) * (data.ndim - 1))

        total_gradients = np.zeros_like(data)
        for _ in tqdm(range(n_samples), leave=False, position=1):
            # One broadcasted draw replaces the per-image loop + concatenate.
            noise = np.random.normal(0.0, stds, data.shape).astype(np.float32)
            gradients, _probs = self.predict_fn(data + noise, labels)
            total_gradients += gradients

        return total_gradients / n_samples

    def _visualize(self, imgs, avg_gradients, bsz, visual, save_path):
        """Render each explanation as a grayscale overlay, then show and/or save it."""
        save_paths = preprocess_save_path(save_path, bsz)
        for i in range(bsz):
            # Sum |gradient| over axis 0 (the channel axis for channel-first data)
            # to get a 2-D saliency map.
            saliency = np.abs(avg_gradients[i]).sum(0)
            vis_explanation = explanation_to_vis(imgs[i], saliency, style='overlay_grayscale')
            if visual:
                show_vis_explanation(vis_explanation)
            if save_paths[i] is not None:
                save_image(save_paths[i], vis_explanation)
