# -*- coding: utf-8 -*-
"""
Created on Mon May 31 17:03:08 2021

@author: vorst
"""


# Python imports

# Third party imports
import numpy as np
import matplotlib.pyplot as plt

# Local imports
from embedding import embed_bag, embed_all_bags 


#%%

# Create some dummy data
# Each instance is generated by one of the following two-dimensional
# probability distributions:
#   N1([5,5]^T, I), N2([5,-5]^T, I), N3([-5,5]^T, I),
#   N4([-5,-5]^T, I), N5([0,0]^T, I)
# N([5,5]^T, I) denotes the normal distribution with mean [5,5] and identity
# covariance matrix.
# Each entry below is (mean, per-axis standard deviation); the axes are sampled
# independently with unit standard deviation.
n = [([5, 5], [1, 1]),
     ([5, -5], [1, 1]),
     ([-5, 5], [1, 1]),
     ([-5, -5], [1, 1]),
     ([0, 0], [1, 1]),]

# The first three entries of `n` (N1, N2, N3) are the distributions that decide
# whether a bag is positive.
N_POSITIVE_DISTRIBUTIONS = 3

# Create 20 positive bags, and 20 negative bags
# A bag is labeled positive if it contains instances from at least two
# different distributions among N1, N2, and N3.
BAG_SIZE = 8
INSTANCE_SPACE = 2
N_POSITIVE_BAGS = 20
positive_bags = np.zeros((N_POSITIVE_BAGS, BAG_SIZE, INSTANCE_SPACE))
N_NEGATIVE_BAGS = 20
negative_bags = np.zeros((N_NEGATIVE_BAGS, BAG_SIZE, INSTANCE_SPACE))


def _sample_instance(distribution_index):
    """Draw one instance of length INSTANCE_SPACE from `n[distribution_index]`."""
    mean, std = n[distribution_index]
    return np.random.normal(mean, std, INSTANCE_SPACE)


for i in range(N_POSITIVE_BAGS):
    # The first two instances come from two *different* distributions among
    # N1, N2, N3. Sampling without replacement is what guarantees the bag is
    # positive; two independent draws could land on the same distribution.
    first, second = np.random.choice(N_POSITIVE_DISTRIBUTIONS, size=2,
                                     replace=False)
    positive_bags[i, 0, :] = _sample_instance(first)
    positive_bags[i, 1, :] = _sample_instance(second)

    # The remaining instances may come from any of the distributions
    for j in range(2, BAG_SIZE):
        positive_bags[i, j, :] = _sample_instance(np.random.randint(0, len(n)))


for i in range(N_NEGATIVE_BAGS):
    # At most one instance of a negative bag may come from N1, N2 or N3, so the
    # bag can never contain two different positive distributions.
    positive_seen = False
    for j in range(BAG_SIZE):
        if positive_seen:
            # A positive distribution was already used: only N4 or N5 remain
            index = np.random.randint(N_POSITIVE_DISTRIBUTIONS, len(n))
        else:
            index = np.random.randint(0, len(n))
            positive_seen = index < N_POSITIVE_DISTRIBUTIONS
        negative_bags[i, j, :] = _sample_instance(index)
        
        
#%% Visualize bags

# Flatten each group of bags into a list of 2-D points so every instance is
# drawn on its own, independent of the bag it belongs to.
positive_points = positive_bags.reshape(-1, 2)
negative_points = negative_bags.reshape(-1, 2)

fig, ax = plt.subplots(1, 1)
for points, marker, label in (
        (positive_points, 'o', 'positive instances'),
        (negative_points, 'x', 'negative instances')):
    # First column on the horizontal axis, second column on the vertical axis
    ax.scatter(points[:, 0], points[:, 1], marker=marker, label=label)

ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.grid(True)
ax.legend()
plt.show()

#%% Apply mapping (embed bag to new space)
# C is the concept class: the set of all training instances, taken from the
# positive bags first and then from the negative bags.
# C = {x^k : k=1, ..., n}
# where x^k is the kth instance in the entire training set
positive_instances = positive_bags.reshape(BAG_SIZE * N_POSITIVE_BAGS,
                                           INSTANCE_SPACE)
negative_instances = negative_bags.reshape(BAG_SIZE * N_NEGATIVE_BAGS,
                                           INSTANCE_SPACE)
C = np.vstack((positive_instances, negative_instances))

# Embed a single bag as an example.
# NOTE(review): the result is presumably of shape (320,), one entry per row
# of C - confirm against embed_bag.
embedded_bag = embed_bag(C, positive_bags[9], 3)

# Stack every bag (positive bags first, negative bags after) and embed them
# all onto the training instances.
bags = np.vstack((positive_bags, negative_bags))
embedded_bags = embed_all_bags(C, bags, 3)

#%% Visualization - reduced bag space

# Embed the bags onto 3 instances for visualization.
# The three hand-picked points lie close to the means of N1, N2 and N3.
_x1 = np.array([4.3, 5.2])
_x2 = np.array([5.4, -3.9])
_x3 = np.array([-6.0, 4.8])
C = np.array((_x1, _x2, _x3))
embedded_bags = embed_all_bags(C, bags, sigma=4.5)

# `bags` stores the positive bags first and the negative bags after them, so
# N_POSITIVE_BAGS is the index at which the two classes split. Using the
# constant instead of a literal keeps the labels right if the number of bags
# changes.
# NOTE(review): the indexing assumes embedded_bags is laid out as
# (concept instance, bag) - confirm against embed_all_bags.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(xs=embedded_bags[0, :N_POSITIVE_BAGS],
           ys=embedded_bags[1, :N_POSITIVE_BAGS],
           zs=embedded_bags[2, :N_POSITIVE_BAGS],
           marker='o',
           label='positive bags')

ax.scatter(xs=embedded_bags[0, N_POSITIVE_BAGS:],
           ys=embedded_bags[1, N_POSITIVE_BAGS:],
           zs=embedded_bags[2, N_POSITIVE_BAGS:],
           marker='x',
           label='negative bags')

# Each axis is the similarity of a bag to one of the three chosen instances
ax.set_xlabel('s(x^i,-)')
ax.set_ylabel('s(x^j,-)')
ax.set_zlabel('s(x^k,-)')
ax.grid(True)
ax.legend()
plt.show()



