# -*- coding: utf-8 -*-
"""
Created on Mon May 31 17:03:08 2021

@author: vorst
"""


# Python imports
import unittest
from typing import Optional, Tuple

# Third party imports
import numpy as np
import matplotlib.pyplot as plt

# Local imports
from pyMILES.embedding import embed_bag, embed_all_bags

# Declarations: dataset dimensions shared by the functions in this module
N_POSITIVE_BAGS = 20  # number of bags that carry a positive label
N_NEGATIVE_BAGS = 20  # number of bags that carry a negative label
BAG_SIZE = 8  # instances contained in every bag
INSTANCE_SPACE = 2  # features per instance (the data is 2-dimensional)

# %%


def generate_dummy_data(
        n_positive_bags: Optional[int] = None,
        n_negative_bags: Optional[int] = None,
        bag_size: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate 2-dimensional multiple-instance bags from five normal distributions.

    Every instance is drawn from one of five 2-D normal distributions with
    identity covariance:

        N1([5,5]^T, I), N2([5,-5]^T, I), N3([-5,5]^T, I),
        N4([-5,-5]^T, I), N5([0,0]^T, I)

    A bag is labeled positive if it contains instances from at least two
    different distributions among N1, N2 and N3. Positive bags are built so
    they always satisfy that rule. Negative bags contain at most one instance
    from N1, N2 or N3, so they can never satisfy it.

    Parameters
    ----------
    n_positive_bags : int, optional
        Number of positive bags. Defaults to ``N_POSITIVE_BAGS``.
    n_negative_bags : int, optional
        Number of negative bags. Defaults to ``N_NEGATIVE_BAGS``.
    bag_size : int, optional
        Number of instances per bag. Defaults to ``BAG_SIZE``.

    Returns
    -------
    positive_bags : np.ndarray
        Array of shape (n_positive_bags, bag_size, 2).
    negative_bags : np.ndarray
        Array of shape (n_negative_bags, bag_size, 2).

    Raises
    ------
    ValueError
        If ``bag_size`` is smaller than 2. A positive bag needs two instances
        from distinct distributions among N1, N2 and N3.
    """
    if n_positive_bags is None:
        n_positive_bags = N_POSITIVE_BAGS
    if n_negative_bags is None:
        n_negative_bags = N_NEGATIVE_BAGS
    if bag_size is None:
        bag_size = BAG_SIZE
    if bag_size < 2:
        raise ValueError(f"bag_size must be at least 2, got {bag_size}")

    # Means of N1..N5 (rows 0..4). Every distribution has unit standard
    # deviation in each independent coordinate, i.e. identity covariance.
    means = np.array([[5, 5], [5, -5], [-5, 5], [-5, -5], [0, 0]], dtype=float)
    n_distributions, n_features = means.shape
    n_positive_distributions = 3  # N1, N2, N3 are rows 0, 1, 2 of `means`

    positive_bags = np.zeros((n_positive_bags, bag_size, n_features))
    for i in range(n_positive_bags):
        # Two *distinct* distributions among N1-N3 guarantee the positive label.
        # Drawing them independently could pick the same distribution twice
        # and produce a bag that does not satisfy the positive rule.
        required = np.random.choice(n_positive_distributions, size=2, replace=False)
        # The remaining instances may come from any of the five distributions.
        others = np.random.randint(0, n_distributions, bag_size - 2)
        indices = np.concatenate((required, others))
        # Broadcasting: one unit-variance sample around each selected mean.
        positive_bags[i] = np.random.normal(means[indices], 1.0)

    negative_bags = np.zeros((n_negative_bags, bag_size, n_features))
    for i in range(n_negative_bags):
        indices = np.random.randint(0, n_distributions, bag_size)
        # Keep only the first draw from N1-N3. Redirect any later ones to N4 or
        # N5 (uniformly), so the bag holds at most one such instance.
        extra_hits = np.flatnonzero(indices < n_positive_distributions)[1:]
        indices[extra_hits] = np.random.randint(
            n_positive_distributions, n_distributions, extra_hits.size)
        negative_bags[i] = np.random.normal(means[indices], 1.0)

    return positive_bags, negative_bags

# %% Visualize bags


def visualize_2_dimensional(positive_bags: np.ndarray, negative_bags: np.ndarray) -> plt.Figure:
    """Scatter-plot every instance of the positive and negative bags.

    Each input array is flattened to (n_instances, 2) rows of (x1, x2)
    coordinates. Positive-bag instances are drawn with 'o' markers and
    negative-bag instances with 'x' markers.

    Returns the created figure. The figure is not shown; the caller decides
    when to call ``plt.show()``.
    """
    figure, axes = plt.subplots(1, 1)

    series = (
        (positive_bags, 'o', 'positive instances'),
        (negative_bags, 'x', 'negative instances'),
    )
    for bag_array, marker_style, legend_text in series:
        points = bag_array.reshape(-1, 2)
        axes.scatter(points[:, 0],
                     points[:, 1],
                     marker=marker_style,
                     label=legend_text)

    axes.set_xlabel('x1')
    axes.set_ylabel('x2')
    axes.grid(True)
    axes.legend()

    return figure

# %% Apply mapping (embed bag to new space)


def embed_dummy_data(
        positive_bags: np.ndarray,
        negative_bags: np.ndarray,
        template_index: int = 9) -> Tuple[np.ndarray, np.ndarray]:
    """Embed all bags onto the concept class formed by every training instance.

    First, one positive bag (``positive_bags[template_index]``) is embedded on
    its own, and the result is printed as a sanity check. The concept class
    contains that bag's own instances, so the embedding is expected to contain
    entries equal to 1. This assumes the similarity used by ``embed_bag``
    peaks at 1 for identical instances; confirm in pyMILES.embedding.

    Then every bag (positive bags first, then negative bags) is embedded with
    ``embed_all_bags``.

    Parameters
    ----------
    positive_bags : np.ndarray
        Positive bags, shape (n_positive_bags, bag_size, n_features).
    negative_bags : np.ndarray
        Negative bags, shape (n_negative_bags, bag_size, n_features).
    template_index : int, optional
        Index of the positive bag used for the printed sanity check
        (default 9).

    Returns
    -------
    embedded_bags : np.ndarray
        Output of ``embed_all_bags`` for all bags.
    bags : np.ndarray
        Positive bags followed by negative bags, concatenated along axis 0.
    """
    n_features = positive_bags.shape[-1]

    # C is the concept class: the set of all training instances from the
    # positive and negative bags, C = {x^k : k = 1, ..., n}, where x^k is the
    # k-th instance of the entire training set. The arrays are flattened using
    # their own shapes rather than the module-level size constants, so this
    # works for any bag count or bag size.
    C = np.concatenate((positive_bags.reshape(-1, n_features),
                        negative_bags.reshape(-1, n_features)),
                       axis=0)

    # Sanity-check embedding of a single bag. Presumably a vector with one
    # entry per concept instance, i.e. shape (len(C),); confirm in pyMILES.
    # The third argument (3) is passed through unchanged. It is presumably the
    # kernel width sigma (TODO confirm).
    embedded_bag = embed_bag(C, positive_bags[template_index], 3)
    print(f"Embedding a bag against a concept class which contains the bag should yield a similarity of 1: {embedded_bag}")

    # Embed all bags onto the training instances
    bags = np.concatenate((positive_bags, negative_bags), axis=0)
    embedded_bags = embed_all_bags(C, bags, 3, distance='euclidean')

    return embedded_bags, bags

# %% Visualization - reduced bag space


def embed_and_visualize_embedded_data(
        bags: np.ndarray,
        n_positive_bags: Optional[int] = None) -> plt.Figure:
    """Embed bags onto three fixed reference instances and plot them in 3-D.

    The embedding space is spanned by three hard-coded 2-D reference instances.
    They lie close to the means of N1 ([5,5]), N2 ([5,-5]) and N3 ([-5,5]) of
    the dummy data, so they act as sample data representative of that data.
    Each bag therefore becomes a point in a 3-dimensional space.

    Parameters
    ----------
    bags : np.ndarray
        All bags, with the positive bags first and the negative bags after
        them (the order returned by ``embed_dummy_data``).
    n_positive_bags : int, optional
        Number of leading entries of ``bags`` that are positive. Defaults to
        ``N_POSITIVE_BAGS``.

    Returns
    -------
    plt.Figure
        The 3-D scatter figure. It is not shown; the caller calls
        ``plt.show()``.
    """
    if n_positive_bags is None:
        n_positive_bags = N_POSITIVE_BAGS

    # Reference instances near the means of N1, N2 and N3
    C = np.array([[4.3, 5.2],
                  [5.4, -3.9],
                  [-6.0, 4.8]])
    embedded_bags = embed_all_bags(C, bags, sigma=4.5)

    # NOTE(review): the slicing assumes embed_all_bags returns an array of
    # shape (n_reference_instances, n_bags); confirm against pyMILES.embedding.
    positive = embedded_bags[:, :n_positive_bags]
    negative = embedded_bags[:, n_positive_bags:]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(xs=positive[0],
               ys=positive[1],
               zs=positive[2],
               marker='o',
               label='positive bags')

    ax.scatter(xs=negative[0],
               ys=negative[1],
               zs=negative[2],
               marker='x',
               label='negative bags')

    ax.set_xlabel('s(x^i,-)')
    ax.set_ylabel('s(x^j,-)')
    ax.set_zlabel('s(x^k,-)')
    ax.grid(True)
    ax.legend()

    return fig

# %% Put it all together, and what do you get?


def main():
    """Run the demo end to end and display the resulting figures.

    Steps: generate the 2-D dummy bags, plot their instances, embed the bags
    onto the training instances, then embed them onto three fixed reference
    instances and plot that 3-D embedding.
    """
    pos, neg = generate_dummy_data()

    # Scatter plot of the raw 2-D instances
    _instances_figure = visualize_2_dimensional(pos, neg)

    # Embedding onto the training instances themselves; only the stacked
    # bags are used further.
    _self_embedding, all_bags = embed_dummy_data(pos, neg)

    # Embedding onto a fixed 3-instance space, plotted in 3-D
    _embedding_figure = embed_and_visualize_embedded_data(all_bags)

    plt.show()


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
