#!/usr/bin/env python

# Generate random vectors based on different algorithms.
# Note that the two functions
# gen_normal_random_spin and gen_Gaussian_random_spin
# are essentially different;
# you have to understand the difference before using them.

import numpy as np

# for multiple spins
# gen_normal_random_spin can generate
# random spins uniformly distributed on the 2-sphere
# while
# gen_spherical_random_spin leads to accumulation near the poles and
# gen_cubic_random_spin leads to accumulation near the cube diagonal directions.
# We recommend that you use gen_normal_random_spin.


def gen_spherical_random_spin(n=1):
    """Generate unit vectors by drawing both spherical angles uniformly.

    The polar angle theta is uniform on [0, pi) and the azimuth phi is
    uniform on [0, 2*pi).  Because theta (rather than cos(theta)) is
    uniform, the resulting points are NOT uniform on the sphere: they
    accumulate near the poles.

    Parameters
    ----------
    n : int
        Number of vectors to generate.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, 3); collapsed to shape (3,) when n == 1.
    """
    theta = np.random.random(size=n) * np.pi
    # phi already spans the full circle, so it is used as-is (no doubling).
    phi = np.random.random(size=n) * 2 * np.pi
    sin_theta = np.sin(theta)
    spins = np.array([
        np.cos(phi) * sin_theta,
        np.sin(phi) * sin_theta,
        np.cos(theta),
    ]).T
    if n == 1:
        spins = spins[0]
    return spins


def gen_cubic_random_spin(n=1):
    """Generate unit vectors by normalising points drawn uniformly in a cube.

    Each Cartesian component is drawn uniformly from [-1, 1) and the
    vector is then scaled to unit length.  The directions are not uniform
    on the sphere: they accumulate along the cube diagonals.

    Parameters
    ----------
    n : int
        Number of vectors to generate.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, 3); collapsed to shape (3,) when n == 1.
    """
    spins = np.random.random(size=(n, 3)) * 2 - 1
    # Normalise all rows at once instead of looping over them in Python.
    spins /= np.linalg.norm(spins, axis=1, keepdims=True)
    if n == 1:
        spins = spins[0]
    return spins


# This modified (rejection-sampling) method generates uniformly distributed
# random vectors, but it is very inefficient. For details please refer to
# M. E. Muller, "A Note on a Method for Generating Points Uniformly on N-Dimensional Spheres." 1959.
def gen_modified_cubic_random_spin(n=1):
    """Generate unit vectors by rejection sampling inside the unit ball.

    Points are drawn uniformly in the cube [-1, 1)^3; only those whose
    length does not exceed 1 are kept, and the survivors are scaled to
    unit length.  Rejected points are redrawn until n vectors have been
    collected.

    Parameters
    ----------
    n : int
        Number of vectors to generate.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, 3); collapsed to shape (3,) when n == 1.
    """
    remaining = n
    good_spins = np.zeros((n, 3))
    while remaining > 0:
        trial = np.random.random(size=(remaining, 3)) * 2 - 1
        norms = np.linalg.norm(trial, axis=1)
        # Keep points inside the unit ball; an exact zero vector has no
        # direction and would give NaN on normalisation, so reject it too.
        keep = (norms <= 1) & (norms > 0)
        n_keep = int(np.count_nonzero(keep))
        start = n - remaining
        # Normalise with the norms already computed for the acceptance test.
        good_spins[start:start + n_keep] = trial[keep] / norms[keep, None]
        remaining -= n_keep
    if n == 1:
        good_spins = good_spins[0]
    return good_spins


def gen_normal_random_spin(n=1):
    """Generate unit vectors uniformly distributed on the 2-sphere.

    Each vector is drawn from an isotropic 3-D normal distribution
    (zero mean, identity covariance) and then scaled to unit length.

    Parameters
    ----------
    n : int
        Number of vectors to generate.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, 3); collapsed to shape (3,) when n == 1.
    """
    spins = np.random.multivariate_normal(np.zeros(3), np.eye(3), n)
    # Normalise all rows at once instead of looping over them in Python.
    spins /= np.linalg.norm(spins, axis=1, keepdims=True)
    if n == 1:
        spins = spins[0]
    return spins


# for a single spin

def gen_Ising_random_spin():
    """Return a random Ising spin, +1.0 or -1.0, with equal probability.

    Returns
    -------
    numpy.float64
        Either 1.0 or -1.0.
    """
    # A comparison is used rather than np.sign(2*u - 1): the sign of zero
    # is 0, so a draw of exactly u == 0.5 would yield an invalid spin.
    u = np.random.random()
    return np.float64(1.0) if u >= 0.5 else np.float64(-1.0)


def gen_XY_random_spin():
    """Return a random 2-D unit vector (an XY spin).

    The in-plane angle is drawn uniformly from [0, 2*pi).

    Returns
    -------
    numpy.ndarray
        Array of shape (2,) holding (cos(angle), sin(angle)).
    """
    angle = np.random.rand() * np.pi * 2
    x_comp = np.cos(angle)
    y_comp = np.sin(angle)
    return np.array([x_comp, y_comp])


def gen_Potts_random_spin(q=3):
    """Return a random q-state Potts (clock) spin as a 2-D unit vector.

    One of the q equally spaced directions on the unit circle,
    angle = 2*pi*k/q with k in {0, ..., q-1}, is chosen at random.

    Parameters
    ----------
    q : int
        Number of allowed directions; must be positive.

    Returns
    -------
    numpy.ndarray
        Array of shape (2,) holding (cos(angle), sin(angle)).

    Raises
    ------
    AssertionError
        If q is not positive.
    """
    assert q>0, 'Potts-type random spins requires q>0!'
    state = np.random.randint(q)
    angle = state*(np.pi*2)/q
    direction = [np.cos(angle), np.sin(angle)]
    return np.array(direction)


def gen_Gaussian_random_spin(vec0,sigma):
    """Perturb vec0 with isotropic Gaussian noise and renormalise.

    A 3-component standard-normal vector scaled by sigma is added to
    vec0, and the sum is scaled to unit length.  The length of the
    displacement is itself random.

    Parameters
    ----------
    vec0 : array_like
        Starting vector; must be addable to a 3-component array.
    sigma : float
        Scale factor applied to the Gaussian noise.

    Returns
    -------
    numpy.ndarray
        The perturbed vector, normalised to unit length.
    """
    noise = np.random.multivariate_normal(np.zeros(3), np.eye(3), 1)[0]
    perturbed = vec0 + sigma * noise
    length = np.linalg.norm(perturbed)
    return perturbed/length


def gen_small_step_random_spin(vec0,sigma):
    """Move vec0 by a step of length sigma in a random direction.

    The step direction is a single unit vector from
    gen_normal_random_spin(); the displaced vector is then scaled to
    unit length.

    Parameters
    ----------
    vec0 : array_like
        Starting vector; must be addable to a 3-component array.
    sigma : float
        Scale factor applied to the random unit step.

    Returns
    -------
    numpy.ndarray
        The displaced vector, normalised to unit length.
    """
    step = gen_normal_random_spin()
    trial = vec0 + sigma*step
    length = np.linalg.norm(trial)
    return trial/length



def gen_random_spins_misc(nn,method='MultivarNormal',q_for_Potts=3):
    """Generate nn random spins with the sampler selected by name.

    Parameters
    ----------
    nn : int
        Number of spins to generate.
    method : str
        One of 'XY_random', 'Potts_random', 'Cubic_random',
        'Spherical_random', 'MultivarNormal' or 'Modified_cubic'.
    q_for_Potts : int
        Number of Potts states; only used when method is 'Potts_random'.

    Returns
    -------
    numpy.ndarray
        For 'XY_random' and 'Potts_random', an array of shape (nn, 2).
        For the other methods, whatever the corresponding 3-D generator
        returns for n=nn (those generators collapse nn == 1 to a single
        1-D vector).

    Raises
    ------
    SystemExit
        If method is not one of the recognised names.
    """
    if method == 'XY_random':
        return np.array([gen_XY_random_spin() for _ in range(nn)])
    if method == 'Potts_random':
        return np.array([gen_Potts_random_spin(q=q_for_Potts) for _ in range(nn)])
    if method == 'Cubic_random':
        return gen_cubic_random_spin(n=nn)
    if method == 'Spherical_random':
        return gen_spherical_random_spin(n=nn)
    if method == 'MultivarNormal':
        return gen_normal_random_spin(n=nn)
    if method == 'Modified_cubic':
        return gen_modified_cubic_random_spin(n=nn)
    # Raise SystemExit directly: the exit() helper is injected by the site
    # module for interactive use and is not guaranteed to exist.
    raise SystemExit('Method of generating random spins ({}) not recognized!'.format(method))
    

# histogram the points generated by random sampling
# according to the surface area of the unit sphere
def hist_spins(ax,spins,tag='normal',color='b',nb=72):
    """Draw a histogram of the polar angles of a set of spins.

    The polar angle (in degrees) is computed from the third component
    of each spin, which is therefore expected to be a unit vector.

    Parameters
    ----------
    ax : object
        Axes-like object to draw on; its hist, set_title, set_xticks,
        set_xlim and set_xlabel methods are called.
    spins : numpy.ndarray
        2-D array whose column 2 holds the z components.
    tag : str
        Title of the axes.
    color : str
        Edge colour of the histogram bars.
    nb : int
        Number of bins.
    """
    # Clamp to [-1, 1] so that rounding overshoot in a normalised
    # component does not make arccos return NaN.
    cos_theta = np.clip(spins[:,2], -1.0, 1.0)
    thetas = np.rad2deg(np.arccos(cos_theta))
    ax.hist(thetas,bins=nb,facecolor='none',edgecolor=color)
    ax.set_title(tag)
    ax.set_xticks(np.arange(0,200,30))
    ax.set_xlim(0,180)
    ax.set_xlabel('$\\theta$')


def hist_spins_by_theta(nn=50000,nb=72):
    """Compare the polar-angle histograms of four 3-D spin samplers.

    For each sampler, nn spins are generated with gen_random_spins_misc
    and their polar angles are histogrammed with hist_spins on one panel
    of a 2x2 figure, which is then shown on screen.

    Parameters
    ----------
    nn : int
        Number of spins drawn per sampler.
    nb : int
        Number of histogram bins.
    """
    import matplotlib.pyplot as plt

    fig,ax=plt.subplots(2,2,sharex=True,sharey=True,figsize=(10,10))

    # Keys must be method names accepted by gen_random_spins_misc
    # (hence 'Spherical_random', not 'Sphere_random').
    # Raw strings keep the LaTeX backslashes without invalid escapes.
    # NOTE(review): the 'Modified_cubic' label mentions theta and phi, although
    # that sampler draws Cartesian components -- confirm the intended text.
    tags = {
    'Cubic_random':     r'$x,\ y,\ z\ \sim\ \mathscr{U}\ (0,1)$',
    'Spherical_random': r'$\theta,\ \phi\ \sim\ \mathscr{U}\ (0,1)$',
    'MultivarNormal':   r'$[x,\ y,\ z]\ \sim\ \mathscr{N}\ (0,1)$',
    'Modified_cubic':   r'$\theta,\ \phi\ \sim\ \mathscr{U}\ (0,1)\ &\ |v|\ <=\ 1$' }

    # ax.flat walks the 2x2 grid row by row, matching ax[ii//2, ii%2].
    for panel,(key,label) in zip(ax.flat, tags.items()):
        spins = gen_random_spins_misc(nn,method=key)
        hist_spins(panel, spins, tag = label, nb=nb)

    ax[0,0].set_ylabel(r'$n\ (\theta)$')
    ax[1,0].set_ylabel(r'$n\ (\theta)$')
    # Dashed guide lines at 45 and 135 degrees on the first panel only.
    for xx in [45,135]: ax[0,0].axvline(xx,ls='--',c='m',alpha=0.5,zorder=-1)
    fig.tight_layout()
    plt.show()


def calc_azimuthal_angle(spins,vv):
    """Return the angle (in radians) between each spin and the vector vv.

    The angle is arccos(spins . vv), so both are expected to be unit
    vectors.  Dot products outside [-1, 1] are clamped to the nearest
    bound before the arccos is taken.

    Parameters
    ----------
    spins : numpy.ndarray
        A single vector of shape (d,) or an array of vectors of shape
        (n, d).
    vv : numpy.ndarray
        Reference vector of shape (d,).

    Returns
    -------
    numpy.ndarray or numpy.float64
        Angles in [0, pi]: shape (n,) for 2-D input, a scalar for a
        single vector.
    """
    # np.clip works for scalars as well as arrays, unlike boolean item
    # assignment, so a single 1-D spin is accepted too.
    dot_p = np.clip(np.dot(spins,vv), -1.0, 1.0)
    return np.arccos(dot_p)
