import numpy as np
import scipy.stats as ss

def add_noise(X, noise, seed = None):
    """Return `X` perturbed by i.i.d. Gaussian noise.

    Parameters
    ----------
    X : ndarray
        Point cloud; the noise array has exactly ``X.shape``.
    noise : float
        Standard deviation of the zero-mean Gaussian noise.
    seed : int or None
        Value passed to ``np.random.seed`` before drawing, so this call
        reseeds NumPy's global random number generator.

    Returns
    -------
    ndarray
        ``X`` plus the sampled noise (a new array; `X` is not modified).
    """
    np.random.seed(seed = seed)
    target_shape = X.shape
    jitter = np.random.normal(scale = noise, size = target_shape)
    return X + jitter


# ===============================================================
#  ----------------------- SHAPES IN 2D ------------------------
# ===============================================================

def two_adjacent_circles(N,
                    r = 1,
                    noise = 0,
                    seed = None):
    """Generates sample points on two circles that touch at one point.

    The first circle has radius `r` and centre ``(r, 0)``. The second has
    radius ``1 - r`` and centre ``(1 + r, 0)``, so for ``0 < r < 1`` the two
    touch at ``(2r, 0)``. With the default ``r = 1`` the second circle has
    radius 0 and collapses to the single point ``(2, 0)``.

    ``N // 2`` points are placed on the first circle and the remaining
    ``N - N // 2`` on the second, so exactly `N` points are returned for
    both even and odd `N`.

    Parameters
    ----------
    N : int
        Total number of sample points.
    r : float
        Radius of the first circle; the second has radius ``1 - r``.
    noise : float
        Standard deviation of the Gaussian noise added by `add_noise`.
    seed : int or None
        Seed for NumPy's global random number generator.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    np.random.seed(seed = seed)

    n_first = N // 2
    n_second = N - n_first
    # Drawing the two angle arrays back to back consumes the RNG stream in
    # the same order as a single draw of shape [2, N // 2], so seeded
    # results for even N are unchanged.
    s = np.random.uniform(0, 2*np.pi, n_first)
    t = np.random.uniform(0, 2*np.pi, n_second)

    x = np.hstack([r*np.cos(s) + r, (1-r)*np.cos(t) + (1+r)])
    y = np.hstack([r*np.sin(s), (1-r)*np.sin(t)])

    X = np.vstack((x,y)).T
    X_noise = add_noise(X, noise, seed = seed)
    return X_noise

def circle(N,
        noise = 0,
        seed = None):
    """Generates sample points on the unit circle centred at the origin.

    Angles are drawn uniformly from [0, 2*pi). Gaussian noise with standard
    deviation `noise` is then added to both coordinates via `add_noise`.

    Parameters
    ----------
    N : int
        Number of sample points.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None
        Seed for NumPy's global random number generator.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    np.random.seed(seed=seed)
    angles = np.random.uniform(0, 2*np.pi, size = N)
    points = np.column_stack((np.cos(angles), np.sin(angles)))
    return add_noise(points, noise, seed = seed)


def vonmises_circles(N,
                kappa = 1,
                noise = 0,
                seed = None):
    """Generates von Mises distributed sample points on the unit circle.

    Angles follow a von Mises distribution with concentration `kappa`,
    centred at angle 0 (scipy's default `loc`). For ``kappa == 0`` they
    are drawn uniformly from [0, 2*pi). Gaussian noise with standard
    deviation `noise` is added exactly once, through `add_noise`.

    Parameters
    ----------
    N : int
        Number of sample points.
    kappa : float
        Concentration parameter; 0 means uniform on the circle.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None
        Seed for NumPy's global random number generator.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    np.random.seed(seed = seed)

    if kappa == 0:
        t = np.random.uniform(0, 2*np.pi, size = N)
    else:
        # No random_state is passed, so scipy presumably draws from NumPy's
        # global RNG seeded above.
        t = ss.vonmises(kappa = kappa).rvs(size = N)

    # Noise is added only via add_noise (previously it was also added
    # inline here, which applied it twice).
    x = np.cos(t)
    y = np.sin(t)

    X = np.vstack((x,y)).T
    X_noise = add_noise(X, noise, seed = seed)
    return X_noise


def annulus(N,
        r = 1,
        noise = 0,
        seed = None):
    """Generates sample points uniformly (by area) in a 2-dimensional annulus.

    The annulus has inner radius `r` and outer radius 1. The squared radius
    is drawn uniformly between ``r**2`` and 1, which yields an area-uniform
    distribution. Gaussian noise with standard deviation `noise` is then
    added via `add_noise`.

    Parameters
    ----------
    N : int
        Number of sample points.
    r : float
        Inner radius of the annulus.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None
        Seed for NumPy's global random number generator.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    np.random.seed(seed = seed)
    draws = np.random.uniform(0, 1, [2, N])

    angle = 2*np.pi*draws[0]
    # Interpolate the squared radius between r**2 and 1, then take the root.
    radius = np.sqrt(r**2 + (1 - r**2)*draws[1])

    points = np.column_stack((radius*np.cos(angle), radius*np.sin(angle)))
    return add_noise(points, noise, seed = seed)


def sinoidal_trajectory(N, shift,
                    noise = 0,
                    seed = None):
    """Generates sample points of the 2-dimensional curve (sin(t), sin(t - shift)).

    For a generic `shift` this curve is an ellipse. It degenerates to a line
    segment when `shift` is a multiple of pi. The parameter ``t`` is drawn
    uniformly from [0, 2*pi), and Gaussian noise with standard deviation
    `noise` is added via `add_noise`.

    Parameters
    ----------
    N : int
        Number of sample points.
    shift : float
        Phase shift between the two coordinates.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None
        Seed for NumPy's global random number generator. Seeding now also
        covers the draw of ``t``, so seeded calls are fully reproducible.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    # Seed before drawing t, consistent with the other generators.
    np.random.seed(seed = seed)
    t = np.random.uniform(0, 2*np.pi, N)
    x = np.sin(t)
    y = np.sin(t - shift)

    X = np.vstack((x,y)).T
    X_noise = add_noise(X, noise, seed = seed)
    return X_noise


# ===============================================================
#  -------------------- SHAPES IN 3D  --------------------------
# ===============================================================

def cylinder(N, height,
        noise = 0,
        seed = None):
    """Generates sample points on the lateral surface of a cylinder in 3D.

    The cylinder has unit radius, its axis is the z-axis, and it spans
    ``0 <= z < height``. Angles and heights are drawn uniformly, and
    Gaussian noise with standard deviation `noise` is added via `add_noise`.

    Parameters
    ----------
    N : int
        Number of sample points.
    height : float
        Height of the cylinder.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None
        Seed for NumPy's global random number generator.

    Returns
    -------
    ndarray of shape (N, 3)
    """
    np.random.seed(seed=seed)
    # Two rows of uniforms are drawn but only the first (the angle) is used.
    # The unused row is kept because dropping it would shift the RNG stream
    # and change the seeded z values drawn below.
    angle = 2*np.pi*np.random.uniform(0, 1, [2, N])[0]
    z_values = np.random.uniform(0, height, N)

    points = np.column_stack((np.cos(angle), np.sin(angle), z_values))
    return add_noise(points, noise, seed = seed)


def torus(N, r,
        noise = 0,
        seed = None):
    """Generates sample points on a torus in 3D.

    The tube's centre circle has radius 1, lies in the xy-plane and is
    centred at the origin. The tube (minor) radius is `r`. Both angles are
    drawn uniformly from [0, 2*pi), so for ``0 < r < 1`` the points are not
    uniform with respect to surface area: they are denser on the inner side
    of the torus. Gaussian noise with standard deviation `noise` is added
    via `add_noise`.

    Parameters
    ----------
    N : int
        Number of sample points.
    r : float
        Tube (minor) radius.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None
        Seed for NumPy's global random number generator.

    Returns
    -------
    ndarray of shape (N, 3)
    """
    np.random.seed(seed=seed)
    theta, psi = np.random.uniform(0, 2*np.pi, [2, N])

    # Distance from the z-axis for each point.
    ring = 1 + r*np.cos(psi)
    points = np.column_stack((ring*np.cos(theta),
                              ring*np.sin(theta),
                              r*np.sin(psi)))
    return add_noise(points, noise, seed = seed)


# ===============================================================
#  ---------------- N-DIMENSIONAL SHAPES  ----------------------
# ===============================================================
def box(N,
        dim = 2,
        noise = 0,
        seed = None,
        low = 0,
        high = 2*np.pi):
    """Generates sample points uniformly distributed in an axis-aligned box.

    Each of the `dim` coordinates is drawn independently from the uniform
    distribution on [`low`, `high`). With the defaults this is
    [0, 2*pi)^dim, not the unit box. Pass ``high=1`` to sample the unit box
    [0, 1)^dim. Gaussian noise with standard deviation `noise` is added via
    `add_noise`.

    Parameters
    ----------
    N : int
        Number of sample points.
    dim : int
        Dimension of the box.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None
        Seed for NumPy's global random number generator.
    low, high : float or array_like
        Lower and upper bounds of the box. Arrays of length `dim` give
        per-axis bounds (broadcast by ``np.random.uniform``).

    Returns
    -------
    ndarray of shape (N, dim)
    """
    np.random.seed(seed=seed)
    X = np.random.uniform(low, high, [N, dim])
    X_noise = add_noise(X, noise, seed = seed)
    return X_noise

def sphere(N,
        dim = 2,
        noise = 0,
        seed = None):
    """Generates sample points uniformly on the unit sphere embedded in R^dim.

    CAUTION: The parameter `dim` refers to the embedding dimension,
    not the intrinsic dimension of the sphere! For example, ``dim=2`` gives
    the unit circle and ``dim=3`` the ordinary 2-sphere.

    Standard Gaussian vectors are normalised to unit length, which gives a
    rotation-invariant (uniform) distribution on the sphere. Gaussian noise
    with standard deviation `noise` is then added via `add_noise`.

    Parameters
    ----------
    N : int
        Number of sample points.
    dim : int
        Embedding dimension.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None
        Seed for NumPy's global random number generator.

    Returns
    -------
    ndarray of shape (N, dim)
    """
    np.random.seed(seed=seed)
    gaussian = np.random.normal(size = (N, dim))
    lengths = np.linalg.norm(gaussian, axis=1, keepdims=True)
    on_sphere = gaussian / lengths
    return add_noise(on_sphere, noise, seed = seed)