#!/usr/bin/env python

# compare the spin-spiral configurations generated by this code (asd)
# and by Spirit (spirit.configuration.spin_spiral);
# the lattice sites and the spin configurations are carefully compared
# Shunhong Zhang <szhang2@ustc.edu.cn>
# July 10, 2021


import matplotlib.pyplot as plt
import numpy as np
from spirit import state,configuration,geometry,io,system
from asd.core.spin_configurations import *
from asd.utility.spin_visualize_tools import *
from asd.core.geometry import *


def asd_gen_spiral(nx=8,ny=8,q_vector=[0.1,0,0],axis=[0,0,1],theta=30,start_phi=0,display=True,lat_type='square'):
    """Generate a spin spiral on an nx x ny lattice with the asd tools.

    Parameters
    ----------
    nx, ny : int
        Number of cells passed to ``build_latt`` (the third dimension is 1).
    q_vector : sequence of 3 floats
        Spiral vector, passed through to ``init_spin_spiral``.
    axis : sequence of 3 floats
        Spiral axis, passed through to ``init_spin_spiral``.
    theta : float
        Angle passed to ``init_spin_spiral``; shown in degrees in the title.
    start_phi : float
        Starting phase, passed through to ``init_spin_spiral``.
    display : bool
        If True, draw the configuration with ``plot_spin_2d`` (show=False,
        so the figure appears at the next ``plt.show``).
    lat_type : str
        Lattice name understood by ``build_latt``.

    Returns
    -------
    sp_lat : ndarray
        Spins returned by ``init_spin_spiral``; presumably the same
        (nx, ny, nat, 3) shape as the zero array it is given — confirm.
    sites_cart : ndarray
        Cartesian site positions, ``np.dot(sites, latt)``.
    """
    latt, sites = build_latt(lat_type, nx, ny, 1, return_neigh=False)
    nat = sites.shape[2]
    sp_lat = np.zeros((nx, ny, nat, 3), float)
    sp_lat = init_spin_spiral(sp_lat, latt, sites, q_vector, theta, start_phi, axis=axis)

    sites_cart = np.dot(sites, latt)

    if display:
        fmt = '[{:4.2f}, {:4.2f}, {:4.2f}]'
        # Backslashes are doubled: '\m' and '\d' are invalid Python escape
        # sequences (SyntaxWarning on 3.12+); the resulting text is unchanged.
        title = 'Generated by asd tools\n'
        title += '$\\mathbf{q}$ = '
        title += fmt.format(*tuple(q_vector))
        title += '\n$\\mathbf{n}$ = '
        title += fmt.format(*tuple(axis))
        title += '\n$\\theta={:4.2f} \\degree$'.format(theta)
        plot_spin_2d(sites_cart, sp_lat, scatter_size=40, quiver_kws=quiver_kws, show=False, title=title)
    return sp_lat, sites_cart


def spirit_gen_spiral(nx=8,ny=8,q_vector=[0.1,0,0],axis=[0,0,1],theta=30,display=False,lat_type='honeycomb'):
    """Generate a spin spiral with Spirit and return it in the asd layout.

    The spiral is written to 'spiral_spirit.ovf' and read back with
    ``parse_ovf``.

    Parameters
    ----------
    nx, ny : int
        Number of cells set with ``geometry.set_n_cells`` (third dimension 1).
    q_vector, axis, theta :
        Passed through to ``configuration.spin_spiral``.
    display : bool
        If True, plot the Spirit configuration and call show.
    lat_type : str
        'honeycomb' (reads 'input.cfg') or 'square' (empty config string).

    Returns
    -------
    sp_lat : ndarray, shape (nx, ny, nat, 3)
        Spins reordered to match the asd (nx, ny, nat) indexing.
    sites_cart : ndarray, shape (nx, ny, nat, 2)
        In-plane Cartesian site positions in the same ordering.

    Raises
    ------
    ValueError
        If ``lat_type`` is not one of the supported lattices.
    """
    # Config file per lattice; '' presumably means Spirit's built-in defaults.
    cfg_files = {'honeycomb': 'input.cfg', 'square': ''}
    if lat_type not in cfg_files:
        raise ValueError('unsupported lat_type {!r}; expected one of {}'.format(
            lat_type, sorted(cfg_files)))
    cfg = cfg_files[lat_type]

    with state.State(cfg,quiet=True) as p_state:
        geometry.set_n_cells(p_state,[nx,ny,1],idx_image=0)
        configuration.spin_spiral(p_state,'',q_vector=q_vector,axis=axis,theta=theta,idx_image=0)
        # get_positions returns a view into the state's memory; copy it so
        # it stays valid after the state is destroyed at the end of `with`.
        pos = np.array(geometry.get_positions(p_state,idx_image=0))
        system.update_data(p_state,idx_image=0)
        io.image_write(p_state,'spiral_spirit.ovf',idx_image=0)

    spins = parse_ovf('spiral_spirit.ovf')[1]
    if display:
        fmt = '[{:4.2f}, {:4.2f}, {:4.2f}]'
        # Backslashes are doubled: '\m' and '\d' are invalid escape sequences.
        title = 'Generated by Spirit\n'
        title += '$\\mathbf{q}$ = '
        title += fmt.format(*tuple(q_vector))
        title += '\n$\\mathbf{n}$ = '
        title += fmt.format(*tuple(axis))
        title += '\n$\\theta={:4.2f} \\degree$'.format(theta)
        plot_spin_2d(pos,spins,scatter_size=40,quiver_kws=quiver_kws,title=title,show=True)
    nat = spins.shape[0]//(nx*ny)
    # The reshape assumes the basis index varies fastest, then the a-cell,
    # then the b-cell; swapping the first two axes gives asd's (nx, ny, nat).
    sp_lat = np.swapaxes(spins.reshape(ny,nx,nat,3),0,1)
    sites_cart = np.swapaxes(pos.reshape(ny,nx,nat,3),0,1)[...,:2]
    return sp_lat,sites_cart


def compare_sites(sites_cart_1,sites_cart_2):
    """Overlay two sets of site positions and report whether they agree.

    Shows a scatter plot (asd as small green dots, spirit as larger hollow
    red circles), prints the per-component minimum and maximum of each set
    over the first three axes, and prints the result of ``np.allclose``
    with ``atol=1e-6``.
    """
    fig, ax = plt.subplots()
    marker_styles = (
        (sites_cart_1, dict(c='g', label='asd', s=5)),
        (sites_cart_2, dict(facecolor='none', edgecolor='r', label='spirit', s=50)),
    )
    for coords, style in marker_styles:
        ax.scatter(coords[..., 0], coords[..., 1], **style)
    ax.legend(scatterpoints=1)
    ax.set_aspect('equal')
    ax.set_axis_off()
    fig.tight_layout()
    plt.show()

    row_fmt = '{:8.5f} {:8.5f}'
    for coords in (sites_cart_1, sites_cart_2):
        lower = np.min(coords, axis=(0, 1, 2))
        upper = np.max(coords, axis=(0, 1, 2))
        print(row_fmt.format(*tuple(lower)))
        print(row_fmt.format(*tuple(upper)))
    consistent = np.allclose(sites_cart_1, sites_cart_2, atol=1e-6)
    print('sites    consistency: ', consistent)


# Script-level defaults. nx, ny and radius are not read by the functions in
# this script (they take their own nx/ny arguments; radius is presumably a
# leftover). quiver_kws is the arrow style both generators pass to
# plot_spin_2d.
nx = ny = 10
radius = 1.5

quiver_kws = {'scale': 1.2, 'units': 'x', 'pivot': 'mid'}


if __name__=='__main__':
    # Each case: keyword arguments shared by both generators, and the
    # start_phi given only to asd_gen_spiral (spirit_gen_spiral has none).
    cases = [
        (dict(nx=10, ny=6, q_vector=[0.1,0,0], axis=[1,1,0], theta=30,
              display=True, lat_type='square'), 10),
        (dict(nx=8, ny=8, q_vector=[0.125,0,0], axis=[1,0,0], theta=30,
              display=True, lat_type='honeycomb'), 90),
    ]

    for kwargs, start_phi in cases:
        sp_lat_1,sites_cart_1 = asd_gen_spiral(start_phi=start_phi,**kwargs)
        sp_lat_2,sites_cart_2 = spirit_gen_spiral(**kwargs)

        # Compare the two outputs: sites via plot + allclose, spins via allclose.
        compare_sites(sites_cart_1,sites_cart_2)
        # Guard the shape first so a layout mismatch is reported, not raised.
        same_shape = np.shape(sp_lat_1) == np.shape(sp_lat_2)
        spins_ok = same_shape and np.allclose(sp_lat_1,sp_lat_2,atol=1e-6)
        print ('spins    consistency: ',spins_ok)
