import math
import matplotlib.pyplot as plt
from .Generaldistribution import Distribution


class Binomial(Distribution):
    """Binomial distribution class for calculating and
    visualizing a Binomial distribution.

    A binomial distribution is defined by two parameters:
        p: the probability of a positive outcome in a single trial
        n: the number of trials
    from which
        mean = p * n
        standard deviation = sqrt(n * p * (1 - p))

    Attributes:
        mean (float): the mean value of the distribution
        stdev (float): the standard deviation of the distribution
        data (list of numbers): trial outcomes (0 or 1 per trial) used by
            replace_stats_with_data(); presumably loaded via the inherited
            read_data_file() -- see Distribution
        p (float): the probability of an event occurring
        n (int): the total number of trials
    """

    def __init__(self, prob=.5, size=20):
        """Create a binomial distribution.

        Args:
            prob (float): probability of a positive outcome per trial
            size (int): number of trials
        """
        self.p = prob
        self.n = size
        # p and n must be set first because the mean and standard deviation
        # handed to the Distribution initializer are derived from them.
        Distribution.__init__(self, self.calculate_mean(), self.calculate_stdev())

    def calculate_mean(self):
        """Calculate the mean from p and n and store it in self.mean.

        Args:
            None
        Returns:
            float: mean of the distribution (p * n)
        """
        self.mean = self.p * self.n

        return self.mean

    def calculate_stdev(self):
        """Calculate the standard deviation from p and n and store it in
        self.stdev.

        Args:
            None
        Returns:
            float: standard deviation of the distribution,
            sqrt(n * p * (1 - p))
        """
        self.stdev = math.sqrt(self.n * self.p * (1 - self.p))

        return self.stdev

    def replace_stats_with_data(self):
        """Recompute p, n, mean and stdev from the data set in self.data.

        Loading a data file does not update the distribution's parameters,
        so this method derives them from the data. n becomes the number of
        trials and p the fraction of positive outcomes (sum of the data
        divided by the number of trials). The mean and standard deviation
        are then recalculated from the new p and n.

        Args:
            None
        Returns:
            float: the p value
            int: the n value
        Raises:
            ZeroDivisionError: if self.data is empty.
        """
        self.n = len(self.data)
        # 1.0 * forces floating-point division even when the data are ints.
        self.p = 1.0 * sum(self.data) / self.n
        # Both helpers store their result on self.
        self.calculate_mean()
        self.calculate_stdev()

        return self.p, self.n

    def plot_bar(self):
        """Plot a bar chart of the expected count of each outcome.

        The bar heights are derived from p and n, not from the raw data:
        (1 - p) * n for outcome '0' and p * n for outcome '1'.
        This method does not call plt.show().

        Args:
            None
        Returns:
            None
        """
        plt.bar(x=['0', '1'], height=[(1 - self.p) * self.n, self.p * self.n])
        plt.title('Bar Chart of Data')
        plt.xlabel('outcome')
        plt.ylabel('count')

    def pdf(self, k):
        """Probability mass function of the binomial distribution.

        Computes C(n, k) * p**k * (1 - p)**(n - k).

        Args:
            k (int): number of positive outcomes
        Returns:
            float: probability of exactly k positive outcomes in n trials;
            0.0 when k lies outside the range [0, n]
        """
        if k < 0 or k > self.n:
            # Impossible outcome. Without this guard math.factorial would
            # raise ValueError on the negative argument.
            return 0.0
        a = math.factorial(self.n) / (math.factorial(k) * math.factorial(self.n - k))
        b = (self.p ** k) * (1 - self.p) ** (self.n - k)

        return a * b

    def plot_bar_pdf(self):
        """Plot the probability mass function of the binomial distribution.

        Draws one bar for every outcome k = 0..n and shows the figure.

        Args:
            None
        Returns:
            list: x values for the pdf plot
            list: y values for the pdf plot
        """
        x = list(range(self.n + 1))
        y = [self.pdf(k) for k in x]

        plt.bar(x, y)
        plt.title('Distribution of Outcomes')
        plt.ylabel('Probability')
        plt.xlabel('Outcome')
        plt.show()

        return x, y

    def __add__(self, other):
        """Add together two Binomial distributions with equal p.

        The result shares the common p and has n = self.n + other.n.

        Args:
            other (Binomial): Binomial instance
        Returns:
            Binomial: Binomial distribution
        Raises:
            AssertionError: if the p values are not equal.
        """
        # An explicit raise is used rather than an `assert` statement, which
        # is stripped under `python -O`. AssertionError is kept so existing
        # callers that catch it keep working.
        if self.p != other.p:
            raise AssertionError('p values are not equal')

        # The constructor derives mean and stdev from p and n.
        return Binomial(self.p, self.n + other.n)

    def __repr__(self):
        """Output the characteristics of the Binomial instance.

        Example: "mean 5, standard deviation 4.5, p .8, n 20"

        Args:
            None
        Returns:
            string: characteristics of the Binomial distribution
        """
        return "mean {}, standard deviation {}, p {}, n {}". \
            format(self.mean, self.stdev, self.p, self.n)
