# AUTOGENERATED! DO NOT EDIT! File to edit: StructureIO.ipynb (unless otherwise specified).

# Public API of this module: the names exported by `from <module> import *`.
# (Generated by nbdev from StructureIO.ipynb; keep in sync with the notebook.)
__all__ = ['atomic_number', 'atoms_color', 'periodic_table', 'Arrow3D', 'fancy_quiver3d', 'write_poscar',
           'export_poscar', 'InvokeMaterialsProject', 'get_kpath', 'read_ticks', 'str2kpath', 'get_kmesh', 'order',
           'rotation', 'get_bz', 'splot_bz', 'iplot_bz', 'to_R3', 'to_basis', 'kpoints2bz', 'fix_sites',
           'translate_poscar', 'get_pairs', 'iplot_lat', 'splot_lat', 'join_poscars', 'repeat_poscar', 'scale_poscar',
           'rotate_poscar', 'mirror_poscar', 'convert_poscar', 'get_transform_matrix', 'transform_poscar', 'add_vaccum',
           'add_atoms', 'remove_atoms', 'replace_atoms']

# Cell
import sys, os, re
import json
import numpy as np
from pathlib import Path
import requests as req
from collections import namedtuple
from itertools import product
from functools import lru_cache

from scipy.spatial import ConvexHull, Voronoi, KDTree
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import matplotlib.colors as mplc #For viewpoint
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
# Inside packages import to work both with package and jupyter notebook.
try:
    from pivotpy import parser as vp, serializer
    from pivotpy import splots as sp
    from pivotpy import utils
except:
    import pivotpy.parser as vp
    import pivotpy.splots as sp
    import pivotpy.serializer as serializer
    import pivotpy.utils as utils


# These colors are taken from Mathematica's ColorData["Atoms"]
# Maps element symbol -> (r, g, b) tuple with channels in [0, 1]. A few entries are written as x/255
# values and look like custom overrides of the Mathematica palette.
# Insertion order is by atomic number (H, He, Li, ..., Og); `_atom_numbers` and `periodic_table`
# rely on this order, so do not reorder entries.
_atom_colors = {'H': (0.7, 0.8, 0.7), 'He': (0.8367, 1.0, 1.0),
    'Li': (0.7994, 0.9976, 0.5436), 'Be': (0.7706, 0.0442, 0.9643), 'B': (1.0, 0.5, 0), 'C': (0.4, 0.4, 0.4), 'N': (143/255,143/255,1), 'O': (0.8005, 0.1921, 0.2015), 'F': (128/255, 1, 0), 'Ne': (0.6773, 0.9553, 0.9284),
    'Na': (0.6587, 0.8428, 0.4922),'Mg': (0.6283, 0.0783, 0.8506),'Al': (173/255, 178/255, 189/255),'Si': (248/255, 209/255, 152/255),'P': (1,165/255,0),'S': (1,200/255,50/255),'Cl': (0,0.9,0),'Ar': (0.5461, 0.8921, 0.8442),
    'K':  (0.534, 0.7056, 0.4207), 'Ca': (0.4801, 0.0955, 0.7446), 'Sc': (0.902, 0.902, 0.902), 'Ti': (0.749, 0.7804, 0.7608), 'V': (0.651, 0.6706, 0.651), 'Cr': (0.5412, 0.7804, 0.6), 'Mn': (0.6118, 0.7804, 0.4784), 'Fe': (0.32,0.33,0.35),
    'Co': (0.9412, 0.6275, 0.5647), 'Ni': (141/255, 142/255, 140/255), 'Cu': (184/255, 115/255, 51/255), 'Zn': (186/255, 196/255, 200/255),'Ga': (90/255, 180/255, 189/255),'Ge': (0.6051, 0.5765, 0.6325),'As': (50/255,71/255,57/255),'Se': (0.9172, 0.0707, 0.6578),
    'Br': (161/255, 61/255, 45/255),'Kr': (0.426, 0.8104, 0.7475),'Rb': (0.4254, 0.5859, 0.3292),'Sr': (0.326, 0.096, 0.6464),'Y': (0.531, 1.0, 1.0),'Zr': (0.4586, 0.9186, 0.9175),'Nb': (0.385, 0.8417, 0.8349),'Mo': (0.3103, 0.7693, 0.7522),
    'Tc': (0.2345, 0.7015, 0.6694), 'Ru': (0.1575, 0.6382, 0.5865), 'Rh': (0.0793, 0.5795, 0.5036), 'Pd': (0.0, 0.5252, 0.4206), 'Ag': (0.7529, 0.7529, 0.7529), 'Cd': (0.8,0.67,0.73), 'In': (228/255, 228/255, 228/255), 'Sn': (0.398, 0.4956, 0.4915),
    'Sb': (158/255,99/255,181/255), 'Te': (0.8167, 0.0101, 0.4513), 'I': (48/255, 25/255, 52/255), 'Xe': (0.3169, 0.7103, 0.6381), 'Cs': (0.3328, 0.4837, 0.2177), 'Ba': (0.1659, 0.0797, 0.556), 'La': (0.9281, 0.3294, 0.7161), 'Ce': (0.8948, 0.3251, 0.7314),
    'Pr': (0.8652, 0.3153, 0.708), 'Nd': (0.8378, 0.3016, 0.663), 'Pm': (0.812, 0.2856, 0.6079), 'Sm': (0.7876, 0.2683, 0.5499), 'Eu': (0.7646, 0.2504, 0.4933), 'Gd': (0.7432, 0.2327, 0.4401), 'Tb': (0.7228, 0.2158, 0.3914), 'Dy': (0.7024, 0.2004, 0.3477),
    'Ho': (0.68, 0.1874, 0.3092), 'Er': (0.652, 0.1778, 0.2768), 'Tm': (0.6136, 0.173, 0.2515), 'Yb': (0.5579, 0.1749, 0.2346), 'Lu': (0.4757, 0.1856, 0.2276), 'Hf': (0.7815, 0.7166, 0.7174), 'Ta': (0.7344, 0.6835, 0.5445), 'W': (0.6812, 0.6368, 0.3604),
    'Re': (0.6052, 0.5563, 0.3676), 'Os': (0.5218, 0.4692, 0.3821), 'Ir': (0.4456, 0.3991, 0.3732), 'Pt': (0.8157, 0.8784, 0.8157), 'Au': (0.8, 0.7, 0.2), 'Hg': (0.7216, 0.8157, 0.7216), 'Tl': (0.651, 0.302, 0.3294), 'Pb': (0.3412, 0.3804, 0.349),
    'Bi': (10/255, 49/255, 93/255), 'Po': (0.6706, 0.0, 0.3608), 'At': (0.4588, 0.2706, 0.3098), 'Rn': (0.2188, 0.5916, 0.5161), 'Fr': (0.2563, 0.3989, 0.0861), 'Ra': (0.0, 0.0465, 0.4735), 'Ac': (0.322, 0.9885, 0.7169), 'Th': (0.3608, 0.943, 0.6717),
    'Pa': (0.3975, 0.8989, 0.628), 'U': (0.432, 0.856, 0.586), 'Np': (0.4645, 0.8145, 0.5455), 'Pu': (0.4949, 0.7744, 0.5067), 'Am': (0.5233, 0.7355, 0.4695), 'Cm': (0.5495, 0.698, 0.4338), 'Bk': (0.5736, 0.6618, 0.3998), 'Cf': (0.5957, 0.6269, 0.3675),
    'Es': (0.6156, 0.5934, 0.3367), 'Fm': (0.6335, 0.5612, 0.3075), 'Md': (0.6493, 0.5303, 0.2799), 'No': (0.663, 0.5007, 0.254), 'Lr': (0.6746, 0.4725, 0.2296), 'Rf': (0.6841, 0.4456, 0.2069), 'Db': (0.6915, 0.42, 0.1858), 'Sg': (0.6969, 0.3958, 0.1663),
    'Bh': (0.7001, 0.3728, 0.1484), 'Hs': (0.7013, 0.3512, 0.1321), 'Mt': (0.7004, 0.331, 0.1174), 'Ds': (0.6973, 0.312, 0.1043), 'Rg': (0.6922, 0.2944, 0.0928), 'Cn': (0.6851, 0.2781, 0.083), 'Nh': (0.6758, 0.2631, 0.0747), 'Fl': (0.6644, 0.2495, 0.0681),
    'Mc': (0.6509, 0.2372, 0.0631), 'Lv': (0.6354, 0.2262, 0.0597), 'Ts': (0.6354, 0.2262, 0.0566), 'Og': (0.6354, 0.2262, 0.0528)}

# Element symbol -> ZERO-based position in `_atom_colors` ('H' -> 0, 'He' -> 1, ...).
# The true atomic number is this index + 1.
_atom_numbers = {k:i for i,k in enumerate(_atom_colors.keys())}

def atomic_number(atom):
    """Return the atomic number of element symbol `atom`, e.g. 'H' -> 1, 'He' -> 2, 'Og' -> 118.

    `_atom_numbers` stores zero-based positions in the element table, so 1 is added here.
    Raises KeyError for an unknown symbol."""
    return _atom_numbers[atom] + 1

def atoms_color():
    """Default per-element RGB colors used for plotting the crystal lattice.

    Each color is a list of three channels rounded to 4 decimals, and the whole
    mapping (element symbol -> color) is wrapped in `serializer.Dict2Data`."""
    rounded = {}
    for symbol, rgb in _atom_colors.items():
        rounded[symbol] = [round(channel, 4) for channel in rgb]
    return serializer.Dict2Data(rounded)

def periodic_table():
    """Display colored elements in a periodic table.

    Builds a 10 x 18 grid: the standard table in rows 0-6, plus lanthanides and actinides in the last
    two rows. Each cell is filled with the element's color from `_atom_colors` and labelled with its
    atomic number and symbol. Empty cells stay white."""
    _copy_names = np.array([f'$^{{{str(i+1)}}}${k}' for i,k in enumerate(_atom_colors.keys())])
    _copy_array = np.array(list(_atom_colors.values()))

    array = np.ones((180,3)) # 10 rows x 18 columns of RGB, white by default
    names = [''] * 180 # keep as list before modification

    # (flat grid cell, element index in `_atom_colors`) pairs, row by row.
    inds = [(0,0),(17,1),
            (18,2),(19,3),*[(30+i,4+i) for i in range(8)],
            *[(48+i,12+i) for i in range(6)],
            *[(54+i,18+i) for i in range(18)],
            *[(72+i,36+i) for i in range(18)],
            *[(90+i,54+i) for i in range(3)],*[(93+i,71+i) for i in range(15)],
            *[(108+i,86+i) for i in range(3)],*[(111+i,103+i) for i in range(15)],
            *[(147+i,57+i) for i in range(14)],
            *[(165+i,89+i) for i in range(14)]
            ]

    for i,j in inds:
        array[i] = _copy_array[j]
        names[i] = _copy_names[j]

    array = np.reshape(array,(10,18,3))
    names = np.reshape(names,(10,18))
    ax = sp.get_axes((9,4.5))
    ax.imshow(array)

    for i in range(18):
        for j in range(10):
            # Black text on light cells, white text on dark ones.
            c = 'k' if np.linalg.norm(array[j,i]) > 1 else 'w'
            # Draw on the axes created above, not on whatever pyplot considers "current".
            ax.text(i,j,names[j,i],color = c,ha='center',va='center')
    ax.set_axis_off()
    plt.show()

# Cell
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

class Arrow3D(FancyArrowPatch):
    """Draw 3D fancy arrow from tail (x, y, z) to head (x+u, y+v, z+w).

    The patch is created with dummy 2D positions. Its real 2D end points are recomputed from the
    3D vertices on every projection, using the parent axes' projection matrix `self.axes.M`.
    Add it to a 3D axes with `.on(ax)`."""
    def __init__(self, x, y, z, u, v, w, *args, **kwargs):
        FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
        self._verts3d = [x,x+u], [y,y+v], [z,z+w]

    def _project(self):
        "Project both 3D end points to 2D, update the patch positions and return the projected z values."
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) # axes.M is used for matplotlib >= 3.4
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def draw(self, renderer):
        self._project()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # `renderer` is optional: newer matplotlib versions call this method without arguments,
        # while older ones passed the renderer. The returned depth is used for z-ordering.
        return min(self._project())

    def on(self,ax):
        "Add this arrow to the (3D) axes `ax`."
        ax.add_artist(self)

def fancy_quiver3d(X,Y,Z,U,V,W,ax=None,C = 'r',L = 0.7,mutation_scale=10,**kwargs):
    """Plots 3D arrows on a given ax. See [FancyArrowPatch](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.patches.FancyArrowPatch.html).
    - **Parameters**
        - X, Y, Z : 1D arrays of coordinates of arrows' tail point.
        - U, V, W : 1D arrays of dx,dy,dz of arrows.
        - ax: 3D axes, if not given, auto created.
        - C : 1D colors array mapping for arrows, or a single color used for every arrow.
        - L : 1D linewidths array mapping for arrows, or a single linewidth used for every arrow.
        - mutation_scale: Arrow head width/size scale. Default is 10.
        - kwargs: FancyArrowPatch's keyword arguments excluding positions, color, lw, mutation_scale, shrinkA and shrinkB, which are already used. An important keyword argument is `arrowstyle` which could be '->','-|>', their inverted forms and many more. See on matplotlib.
    - **Returns**
        - The axes the arrows were drawn on.
    """
    if not ax:
        ax = sp.get_axes(figsize=(3.4,3.4),axes_3d=True) # square figure keeps the same aspect ratio

    if isinstance(C,(list,np.ndarray)):
        colors = np.array(C)
    else:
        colors = np.array([[*mplc.to_rgb(C)] for _ in X]) # broadcast a single color to every arrow

    if isinstance(L,(list,np.ndarray)):
        widths = L
    else:
        widths = [L for _ in X] # broadcast a single linewidth to every arrow

    patch_opts = {'mutation_scale': mutation_scale, 'shrinkA': 0, 'shrinkB': 0}
    for x, y, z, u, v, w, color, width in zip(X, Y, Z, U, V, W, colors, widths):
        Arrow3D(x, y, z, u, v, w, color=color, lw=width, **patch_opts, **kwargs).on(ax)

    return ax

# Cell
def write_poscar(poscar_data, outfile = None, sd_list = None, overwrite = False):
    """Writes poscar data object to a file or prints it as a string.
    - **Parameters**
        - poscar_data: Output of `export_poscar`,`join_poscars` etc.
        - outfile  : str, file path to write on. If not given, the POSCAR content is printed instead.
        - sd_list  : A list ['T T T','F F F',...] of strings to turn on selective dynamics at required sites. len(sd_list)==len(sites) must hold, otherwise ValueError is raised.
        - overwrite: bool, if `outfile` already exists it is replaced only when overwrite=True; otherwise FileExistsError is raised.

    **Note**: POSCAR is only written in direct format even if it was loaded from cartesian format.
    Lattice vectors are written divided by `extra_info.scale`, and that scale is written on the second line.
    """
    info = poscar_data.extra_info
    comment = info.comment
    header = f'{poscar_data.SYSTEM}  # ' + (comment or 'Created by Pivopty')
    scale = info.scale
    scale_line = "\n  {:<20.14f}\n".format(scale)
    lattice = ["{:>22.16f}{:>22.16f}{:>22.16f}".format(*vec) for vec in poscar_data.basis/scale]
    species = poscar_data.unique.to_dict() # element -> site indices

    chunks = [
        header,
        scale_line,
        '\n'.join(lattice),
        "\n  " + '    '.join(species.keys()),
        "\n  " + '    '.join([str(len(sites)) for sites in species.values()]),
    ]
    if sd_list:
        chunks.append("\nSelective Dynamics")
    chunks.append("\nDirect\n")

    rows = ["{:>21.16f}{:>21.16f}{:>21.16f}".format(*xyz) for xyz in poscar_data.positions]
    if sd_list:
        if len(rows) != len(sd_list):
            raise ValueError("len(sd_list) != len(sites).")
        rows = [f"{p}   {s}" for p, s in zip(rows, sd_list)]
    chunks.append('\n'.join(rows))
    content = ''.join(chunks)

    if not outfile:
        print(content)
        return

    if os.path.isfile(outfile) and not overwrite:
        raise FileExistsError(f"{outfile!r} exists, can not overwrite, \nuse overwrite=True if you want to chnage.")
    with open(outfile,'w') as f:
        f.write(content)

def export_poscar(path = None,content = None):
    """Export POSCAR file to python objects. Only Direct POSCAR supported.
    - **Parameters**
        - path: Path/to/POSCAR file. Auto picks in CWD.
        - content: POSCAR content as string, This takes precedence to path.

    A negative scale factor on line 2 is the total cell volume (VASP convention). It is converted to the
    equivalent positive linear scale, so `basis` always holds the real lattice vectors.
    Raises FileNotFoundError if no path/content is given and ./POSCAR does not exist, and
    NotImplementedError for Cartesian-format positions.
    """
    if content and isinstance(content,str):
        path = content.splitlines() # This acts as slice for islice2array.
    elif not path:
        path = './POSCAR'
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{path!r} not found.")
    header = vp.islice2array(path,start=0,nlines=1,raw=True,exclude=None).split('#',1)
    SYSTEM = header[0].strip()
    comment = header[1].strip() if len(header) > 1 else 'Exported by Pivopty'

    scale = float(vp.islice2array(path,start=1,nlines=1,exclude=None,raw=True).strip())
    raw_basis = vp.islice2array(path,start=2,nlines=3,exclude=None).reshape((-1,3))
    if scale < 0:
        # Negative scale means "total volume": pick the linear scale s with s**3 * |det(raw)| == |volume|.
        scale = (-scale / abs(np.linalg.det(raw_basis))) ** (1/3)
    basis = scale*raw_basis
    out_dict = {'SYSTEM':SYSTEM,
                'basis':basis,
                'extra_info':{'comment':comment,'scale':scale}}

    elems = vp.islice2array(path,raw=True,start=5,nlines=1,exclude=None).split()
    ions = vp.islice2array(path,start=6,nlines=1,exclude=None)
    N = np.sum(ions).astype(int)
    inds = np.cumsum([0,*ions]).astype(int) # boundaries of each element's block of sites
    # Check Cartesian and Selective Dynamics: either of lines 8/9 may hold the coordinate mode.
    lines = vp.islice2array(path,start=7,nlines=2,exclude=None,raw=True).splitlines()
    lines = [l.strip() for l in lines] # remove whitespace or tabs
    out_dict['extra_info']['cartesian'] = any(l[:1] in ('c','C','k','K') for l in lines[:2])
    # Two lines are excluded in below command before start. so start = 7-2
    # NOTE(review): this assumes the excluded (letter-leading) lines are the comment line and the element
    # line, i.e. that the comment line starts with a letter — confirm against islice2array semantics.
    positions = vp.islice2array(path,start=5,exclude=r"^\s+[a-zA-Z]|^[a-zA-Z]",cols=[0,1,2]).reshape((-1,3))[:N]
    if out_dict['extra_info']['cartesian']:
        raise NotImplementedError(("Cartesian format is not supported for POSCAR file, "
                                  "but structure in vasprun.xml is readable using any of "
                                  "pivotpy.parser.[export_vasprun, export_spin_data, get_structure]"))

    unique_d = {e: range(inds[i],inds[i+1]) for i,e in enumerate(elems)}
    out_dict.update({'positions':positions,
                     'unique':unique_d})
    return serializer.PoscarData(out_dict)

# Cell
def _save_mp_API(api_key):
    """
    - Save materials project api key for autoload in functions.
    """
    home = str(Path.home())
    file = os.path.join(home,'.pivotpyrc')
    lines = []
    if os.path.isfile(file):
        with open(file,'r') as fr:
            lines = fr.readlines()
            lines = [line for line in lines if 'MP_API_KEY' not in line]

    with open(file,'w') as fw:
        fw.write("MP_API_KEY = {}".format(api_key))
        for line in lines:
            fw.write(line)

# Cell
def _load_mp_data(formula,api_key=None,mp_id=None,max_sites = None, min_sites = None):
    """
    - Returns fetched data using request api of python from materials project website.
    - **Parameters**
        - formula  : Material formula such as 'NaCl'.
        - api_key  : API key for your account from material project site. Auto picks if you already used `_save_mp_API` function.
        - mp_id     : Optional, you can specify material ID to filter results. Applied after the site filter; if no entry matches, the (site-filtered) list is returned.
        - max_sites : Maximum number of sites. If None, sets `min_sites + 1`, if `min_sites = None`, gets all data.
        - min_sites : Minimum number of sites. If None, sets `max_sites - 1`, if `max_sites = None`, gets all data.
    - **Raises**
        - ValueError if no api_key can be found, the request fails, or the response has no 'response' field.
    """
    if api_key is None:
        file = os.path.join(str(Path.home()),'.pivotpyrc')
        try:
            with open(file,'r') as f:
                for line in f:
                    if 'MP_API_KEY' in line:
                        api_key = line.split('=',1)[1].strip() # split once: keys may contain '='
        except (OSError, IndexError):
            api_key = None
        if not api_key: # file missing, unreadable, or without a key line
            raise ValueError("api_key not given. provide in argument or generate in file using `_save_mp_API(your_mp_api_key)")

    # NOTE(review): this is the legacy v2 REST endpoint; confirm it is still served.
    url = r"https://legacy.materialsproject.org/rest/v2/materials/{}/vasp?API_KEY={}".format(formula,api_key)
    resp = req.request(method='GET',url=url)
    if resp.status_code != 200:
        raise ValueError("Error in fetching data from materials project. Try again!")

    jl = json.loads(resp.text)
    if 'response' not in jl: #check if response
        raise ValueError("Either formula {!r} or API_KEY is incorrect.".format(formula))

    results = jl['response']

    if max_sites is not None or min_sites is not None:
        # A missing bound is derived from the other one, giving a window of width 1.
        lower = min_sites if min_sites is not None else max_sites - 1
        upper = max_sites if max_sites is not None else min_sites + 1
        results = [res for res in results if lower <= res['nsites'] <= upper]

    # Filter to mp_id at last, on top of any site filtering. more preferred
    if mp_id is not None:
        for res in results:
            if mp_id == res['material_id']:
                return [res]
    return results

# Cell
class InvokeMaterialsProject:
    """Connect to materials project and get data using `api_key` from their site.
    Usage:
    ```python
    from pivotpy.sio import InvokeMaterialsProject # or import pivotpy.InvokeMaterialsProject as InvokeMaterialsProject
    mp = InvokeMaterialsProject(api_key='your_api_key')
    outputs = mp.request(formula='NaCl') #returns list of structures from response
    outputs[0].export_poscar() #returns poscar data
    outputs[0].cif #returns cif data
    ```

    `request` is wrapped in `functools.lru_cache(maxsize=2)`, so repeating one of the two most recent
    calls with identical arguments (on the same instance) returns the cached structures without a new request."""
    def __init__(self,api_key=None):
        "Request Materials Project access. api_key is on their site. You only need it once; save it with `save_api_key` for later."
        self.api_key = api_key
        self.__response = None
        self.success = False

    def save_api_key(self,api_key):
        "Save api_key for auto reloading later."
        _save_mp_API(api_key)

    @lru_cache(maxsize=2) #cache for 2 calls
    def request(self,formula,mp_id=None,max_sites = None,min_sites=None):
        """Fetch data using request api of python from materials project website. Returns a list of `Structure` objects
        with `cif`, `export_poscar`, `write_cif` and `write_poscar`.
        - **Parameters**
            - formula  : Material formula such as 'NaCl'.
            - mp_id     : Optional, you can specify material ID to filter results.
            - max_sites : Maximum number of sites. If None, sets `min_sites + 1`, if `min_sites = None`, gets all data.
            - min_sites : Minimum number of sites. If None, sets `max_sites - 1`, if `max_sites = None`, gets all data.
        - **Raises**
            - requests.HTTPError if the filtered response is empty.
        """
        self.__response = _load_mp_data(formula = formula,api_key = self.api_key, mp_id = mp_id, max_sites = max_sites,min_sites=min_sites)
        if self.__response == []:
            raise req.HTTPError("Error in request. Check your api_key or formula.")

        class Structure:
            "One Materials Project entry: its CIF text plus space group, crystal system, formula and id."
            def __init__(self,response):
                self._cif    = response['cif']
                self.symbol  = response['spacegroup']['symbol']
                self.crystal = response['spacegroup']['crystal_system']
                self.unit    = response['unit_cell_formula']
                self.mp_id   = response['material_id']

            @property
            def cif(self):
                return self._cif

            def __repr__(self):
                return f"Structure(unit={self.unit},mp_id={self.mp_id!r},symbol={self.symbol!r},crystal={self.crystal!r},cif='{self._cif[:10]}...')"

            def write_cif(self,outfile = None):
                "Write the CIF to `outfile` if it is a string path, else print it."
                if isinstance(outfile,str):
                    with open(outfile,'w') as f:
                        f.write(self._cif)
                else:
                    print(self._cif)

            def write_poscar(self,outfile = None, overwrite = False):
                "Use `pivotpy.api.POSCAR.write/pivotpy.sio.write_poscar` if you need extra options."
                write_poscar(self.export_poscar(),outfile = outfile, overwrite = overwrite)

            def export_poscar(self):
                """Convert this entry's CIF to POSCAR text and parse it with the module-level `export_poscar`.

                Lattice vectors are expressed in units of `a` (written as the scale line), with `a` along x
                and `b` in the xy-plane. Sites are grouped by the species symbol in the first column of the
                atom-site loop, in order of first appearance. The element and count lines are derived from
                that grouping, so they always match the position order."""
                # Drop every blank line (not just the first), so a trailing newline cannot produce an empty "site".
                lines = [line for line in self._cif.split('\n') if line.strip()]
                cell = {} # e.g. '_cell_length_a' -> '5.69'
                for ys in lines:
                    words = ys.split()
                    if '_cell' in ys and len(words) > 1:
                        cell[words[0]] = words[1]
                    if '_structural' in ys:
                        top = words[1] + f" # [{self.mp_id!r}][{self.symbol!r}][{self.crystal!r}] Created by pivotpy using Materials Project Database"

                index = 0
                for i,ys in enumerate(lines):
                    if '_atom_site_occupancy' in ys:
                        index = i + 1 # start collecting pos.

                # Columns used: 0 = type symbol, 3:6 = fractional x, y, z.
                sites = {}
                for pos in lines[index:]:
                    s_p = pos.split()
                    sites.setdefault(s_p[0], []).append("{0:>12}  {1:>12}  {2:>12}  {3}\n".format(*s_p[3:6],s_p[0]))
                pos_str = ''.join(''.join(rows) for rows in sites.values())
                elems = '\t'.join(sites.keys())
                nums  = '\t'.join(str(len(rows)) for rows in sites.values())

                a_len = float(cell['_cell_length_a'])
                b_len = float(cell['_cell_length_b'])
                c_len = float(cell['_cell_length_c'])
                alpha, beta, gamma = np.deg2rad([float(cell[f'_cell_angle_{k}']) for k in ('alpha','beta','gamma')])
                volume = float(cell['_cell_volume'])

                bx, by = b_len*np.cos(gamma), b_len*np.sin(gamma)
                cx = c_len*np.cos(beta)
                cy = (b_len*c_len*np.cos(alpha) - bx*cx)/by
                cz = volume/(a_len*by) # V = a * (b sin(gamma)) * cz

                fmt = "{0:>22.16f} {1:>22.16f} {2:>22.16f}"
                a = fmt.format(1.0,0.0,0.0)                    # lattice vector a
                b = fmt.format(bx/a_len,by/a_len,0.0)          # lattice vector b
                c = fmt.format(cx/a_len,cy/a_len,cz/a_len)     # lattice vector c

                content = f"{top}\n  {a_len}\n {a}\n {b}\n {c}\n  {elems}\n  {nums}\nDirect\n{pos_str}"
                return export_poscar(content = content)

        # get cifs
        structures = [Structure(res) for res in self.__response]

        self.success = True # set success flag
        return structures

# Cell
def get_kpath(*patches, n = 5,weight= None ,ibzkpt = None,outfile=None, rec_basis = None):
    """
    Generate list of kpoints along high symmetry path. Options are write to file or print KPOINTS content.
    It generates uniformly spaced point with input `n` as just a scale factor of number of points per unit length.
    You can also specify custom number of kpoints in an interval by putting number of kpoints as 4th entry in left kpoint.
    - **Parameters**
        - *patches : Any number of disconnected patches where a single patch is a dictionary like {'label': (x,y,z,[N]), ...} where x,y,z is high symmetry point and
                    N (optional) is number of points in current interval, points in a connected path patch are at least two i.e. `{'p1':[x1,y1,z1],'p2':[x2,y2,z2]}`.
                    A key of a patch should be string representing the label of the high symmetry point. A key that starts with '_' is ignored, so you can add points without high symmetry points as well.
        - n        : int, number per length of body diagonal of rec_basis, this makes uniform steps based on distance between points.
        - weight : Float, if None, auto generates weights: `1/number_of_path_points`, or 0 when `ibzkpt` is given
                   (band-path points appended to an IBZKPT, e.g. for HSE, must carry zero weight).
        - ibzkpt : Path to ibzkpt file, required for HSE calculations. Only its k-point block is copied (a trailing
                   tetrahedra section is ignored). Raises FileNotFoundError if the file does not exist.
        - outfile: Path/to/file to write kpoints.
        - rec_basis: Reciprocal basis 3x3 array to use for calculating uniform points.

    If `outfile = None`, KPOINTS file content is printed.
    The header line records HSK-INDS (tick indices), LABELS and SEG-INDS (path-break indices), which `read_ticks` parses back.
    """
    if len(patches) == 0:
        raise ValueError("Please provide at least one high symmetry path consisting of two points.")

    hsk_list, labels = [], []
    for patch in patches:
        if not isinstance(patch,dict):
            raise TypeError("Patche must be a dictionary as {'label': (x,y,z,[N]), ...}")
        if len(patch.keys()) < 2:
            raise ValueError("Please provide at least one high symmetry path consisting of two points.")
        _patch = []
        for k,v in patch.items():
            if not isinstance(k,str):
                raise TypeError("Label must be a string")
            # Reject non-sequences AND sequences of the wrong length.
            if (not isinstance(v,(list,tuple,set,np.ndarray))) or (len(v) not in (3,4)):
                raise TypeError("Value must be a list or tuple of length 3 or 4, like (x,y,z,[N])")
            labels.append('skip' if k.startswith('_') else k)
            _patch.append(v)
        hsk_list.append(_patch)

    xs,ys,zs, inds,joinat = [],[],[],[0],[] # 0 in inds list is important
    _labels = []
    _len_prev = 0
    for j,a in enumerate(hsk_list):
        for i in range(len(a)-1):
            try:
                _m = int(a[i][3]) # number of points given explicitly.
            except IndexError:
                if rec_basis is not None and np.size(rec_basis) == 9:
                    basis = np.array(rec_basis)
                    coords = to_R3(basis,[a[i][:3],a[i+1][:3]])
                    largest_dist = np.linalg.norm(basis.sum(axis=0)) # body diagonal
                    _m = int(np.rint(np.linalg.norm(coords[0] - coords[1])*n/largest_dist))
                else:
                    _vec = [_a-_b for _a,_b in zip(a[i][:3],a[i+1][:3])]
                    _m = int(np.rint(np.linalg.norm(_vec)*n))
            # Plain ints keep the header free of numpy reprs such as `np.int64(4)`.

            _m = max(_m, 2) # minimum of 2 points in a path

            inds.append(inds[-1]+_m) #Append first then do next
            _labels.append(labels[i+_len_prev])
            if j !=0 and i == 0:
                joinat.append(inds[-2]) # Start of this patch is a path break.
                # The previous patch's last label and this patch's first label share the break index.
                # Keep exactly one label there so `inds` and `_labels` stay aligned.
                if 'skip' not in _labels[-2]:
                    _labels[-1] = _labels[-2] + '|' + _labels[-1]
                _labels = [*_labels[:-2],_labels[-1]]

            xs.extend(np.linspace(a[i][0],a[i+1][0],_m))
            ys.extend(np.linspace(a[i][1],a[i+1][1],_m))
            zs.extend(np.linspace(a[i][2],a[i+1][2],_m))

        _labels.append(labels[len(a) -1 +_len_prev]) # Add last in current interval
        _len_prev += len(a)

    if weight is None and xs:
        weight = 0 if ibzkpt is not None else 1/len(xs)

    out_str = '\n'.join("{0:>16.10f}{1:>16.10f}{2:>16.10f}{3:>12.6f}".format(x,y,z,weight) for x,y,z in zip(xs,ys,zs))
    N = np.size(xs)
    if ibzkpt is not None:
        if not os.path.isfile(ibzkpt):
            raise FileNotFoundError(f"{ibzkpt!r} not found.")
        with open(ibzkpt,'r') as f:
            lines = f.readlines()

        n_ibz = int(lines[1].strip()) # IBZKPT: comment, count, mode line, then n_ibz k-point lines
        N += n_ibz
        ibz_str = ''.join(lines[3:3+n_ibz]).rstrip('\n') # no blank line between the two blocks
        out_str = "{}\n{}".format(ibz_str,out_str)

    inds[-1] = -1 # last index to -1
    inds = [i for k,i in enumerate(inds) if _labels[k] != 'skip']
    _labels = [l.replace('|skip','') for l in _labels if l != 'skip']
    top_str = "Automatically generated using PivotPy with HSK-INDS = {}, LABELS = {}, SEG-INDS = {}\n\t{}\nReciprocal Lattice".format(inds,_labels,joinat,N)
    out_str = "{}\n{}".format(top_str,out_str)
    if outfile is not None:
        with open(outfile,'w') as f:
            f.write(out_str)
    else:
        print(out_str)

def read_ticks(kpoints_file_path):
    """Read tick indices/labels stored in the first (header) line of a KPOINTS file.

    Returns a dict with keys `ktick_inds`, `ktick_vals` and `kseg_inds`, ready to be unpacked into
    plotting functions. A missing file, or a header without these entries, yields empty lists (still valid)."""
    ticks = dict(ktick_inds=[],ktick_vals=[],kseg_inds=[])
    if not os.path.isfile(kpoints_file_path):
        return ticks

    head = vp.islice2array(kpoints_file_path,exclude=None,raw=True,nlines=1)

    def _bracketed(key):
        "Comma-separated items inside the first [...] after `key` in the header."
        return head.split(key)[1].split(']')[0].split('[')[1].split(',')

    if 'HSK-INDS' in head:
        ticks['ktick_inds'] = [int(item) for item in _bracketed('HSK-INDS') if item]
    if 'LABELS' in head:
        ticks['ktick_vals'] = [item.replace("'","").replace('"','').strip() for item in _bracketed('LABELS') if item]
    if 'SEG-INDS' in head:
        ticks['kseg_inds'] = [int(item) for item in _bracketed('SEG-INDS') if item]
    return ticks

# Cell
class _LabeledPatch(dict):
    """A `dict` patch that also keeps every (label, point) pair in input order, duplicates included.

    `get_kpath` reads a patch only through `isinstance(..., dict)`, `keys()` and `items()`. Returning the full
    pair list from those methods lets a path revisit a label (e.g. G-X-M-G, or several 'skip'/unlabeled points)
    without the repeated key collapsing in a plain dict."""
    def __init__(self, pairs):
        pairs = list(pairs)
        super().__init__(pairs)
        self._pairs = pairs

    def keys(self):
        return [k for k, _ in self._pairs]

    def values(self):
        return [v for _, v in self._pairs]

    def items(self):
        return list(self._pairs)

def str2kpath(kpath_str,n = 5, weight = None, ibzkpt = None, outfile = None, rec_basis = None):
    r"""Get Kpath from a string of kpoints (Line-Mode like). Useful in Terminal.
    - **Parameters**
        - kpath_str: str, a multiline string similar to line mode of KPOINTS, initial 4 lines are not required.
            - If you do not want to label a point, label it as 'skip' and it will be removed.
            - You can add an integer at end of a line to customize number of points in a given patch.
            - Each empty line breaks the path, so similar points before and after empty line are useless here.
            - A label may appear more than once in a patch (e.g. a closed path G-X-M-G); every point is kept.
        - n      : int, number per length of body diagonal of rec_basis, this makes uniform steps based on distance between points.
        - weight : Float, if None, auto generates weights.
        - ibzkpt : Path to ibzkpt file, required for HSE calculations.
        - outfile: Path/to/file to write kpoints.
        - rec_basis: Reciprocal basis 3x3 array to use for calculating uniform points.

    - **Example**
        > str2kpath('''0 0 0 !$\Gamma$ 3
                    0.25 0.25 0.25 !L''')
        > Automatically generated using PivotPy with HSK-INDS = [0, -1], LABELS = ['$\\Gamma$', 'L'], SEG-INDS = []
        >   3
        > Reciprocal Lattice
        >   0.0000000000    0.0000000000    0.0000000000    0.333333
        >   0.1250000000    0.1250000000    0.1250000000    0.333333
        >   0.2500000000    0.2500000000    0.2500000000    0.333333
    """
    lines = kpath_str.splitlines()

    skipN = 0
    for _n, line in enumerate(lines[:6]): #Handle strings from AFLOW-like softwares
        if line.strip().isalpha():
            skipN = _n + 1

    body = lines[skipN:]
    # Blank-line positions are taken relative to `body`, the same lines `hsk_list` is built from.
    where_blanks = [i for i,line in enumerate(body) if line.strip() == '']

    label_re = re.compile(r'\$\\[a-zA-Z]+\$|[a-zA-Z]+|[α-ωΑ-Ω]+|\|')
    number_re = re.compile(r'[-+\d]*[.][\d+]+|[-+]*\d+/\d+|[-+]*\d+')

    hsk_list, labels = [],[]
    for j,line in enumerate(body):
        if line.strip():
            _labs = label_re.findall(line)
            labels.append(_labs[0] if _labs else '')
            _ks = number_re.findall(line)
            _ks = [[float(k) for k in w.split('/')] if '/' in w else float(w) for w in _ks]

            for i,_k in enumerate(_ks):
                if isinstance(_k, list) and len(_k) == 2:
                    _ks[i] = _k[0]/_k[1]
                elif isinstance(_k, list) and len(_k) != 2:
                    print(f'Check if you provide fraction correctly in line {j+1+skipN}!')

            if len(_ks) == 4:
                try: _ks[3] = int(_ks[3])
                except: print(f'4th number in line {j+1+skipN} should be integer!')

            hsk_list.append(_ks)

    if where_blanks:
        # Convert blank-line positions into indices of `hsk_list` where a new patch starts.
        filtered = [w-i for i,w in enumerate(where_blanks)]
        where_blanks = np.unique([0,*filtered,len(hsk_list)]).tolist()

    patches = [] if where_blanks else [_LabeledPatch(zip(labels,hsk_list)),]
    for a,b in zip(where_blanks[:-1], where_blanks[1:]):
        if b - a < 2:
            raise ValueError(f"There should be at least two points in a patch of path at line {a+1}!")
        patches.append(_LabeledPatch(zip(labels[a:b],hsk_list[a:b])))

    return get_kpath(*patches,n=n,weight=weight,ibzkpt=ibzkpt,outfile=outfile, rec_basis = rec_basis)

# Cell
def _get_basis(path_pos):
    """Returns given(computed) and inverted basis as tuple(given,inverted).
    - **Parameters**
        - path_pos: path/to/POSCAR or 3 given vectors as rows of a matrix."""
    if isinstance(path_pos,(list,np.ndarray)) and np.ndim(path_pos) ==2:
        basis = np.array(path_pos)
    elif isinstance(path_pos,str) or isinstance(path_pos,type(None)):
        basis = export_poscar(path_pos).basis
    else:
        raise FileNotFoundError("{!r} does not exist or not 3 by 3 list.".format(path_pos))
    # Process. 2π is not included in vasp output
    rec_basis = np.linalg.inv(basis).T # Compact Formula
    Basis = namedtuple('Basis', ['given', 'inverted'])
    return Basis(basis,rec_basis)

# Cell
def get_kmesh(poscar_data, *args, shift = 0, weight = None, cartesian = False, ibzkpt= None, outfile=None):
    """**Note**: Use `pivotpy.POSCAR.get_kmesh` to get k-mesh based on current POSCAR.
    - Generates uniform mesh of kpoints. Options are write to file, or print KPOINTS content.
    - **Parameters**
        - poscar_data: export_poscar() or export_vasprun().poscar().
        - *args: 1 or 3 integers which decide shape of mesh. If 1, the direction with the largest norm
          (reciprocal vector length, or Cartesian extent if `cartesian=True`) gets that many points and the
          others are scaled proportionally (rounded, but at least 1 point each).
        - shift  : Only works if cartesian = False. Default is 0. Could be a number or list of three numbers to add to interval [0,1].
        - weight : Float, if None, uses 1/N where N is the number of generated mesh points.
        - cartesian: If True, generates cartesian mesh spanning the bounding box of the BZ vertices, in units of 2π/SCALE.
        - ibzkpt : Path to IBZKPT file, required for HSE calculations. Its k-points are written before the mesh points.
        - outfile: Path/to/file to write kpoints.

    If `outfile = None`, KPOINTS file content is printed. Nothing is returned.
    - **Raises**
        - ValueError: if mesh arguments are not 1 or 3 integers, the mesh is empty,
          or a reciprocal mesh is requested together with a cartesian IBZKPT.
    """
    if len(args) not in [1,3]:
        raise ValueError("get_kmesh() takes 1 or 3 args!")

    int_types = (int, np.integer)
    if cartesian:
        norms = np.ptp(poscar_data.rec_basis,axis=0)
    else:
        norms = np.linalg.norm(poscar_data.rec_basis, axis = 1)

    if len(args) == 1:
        if not isinstance(args[0], int_types):
            raise ValueError("get_kmesh expects integer for first positional argument!")
        weights = norms/np.max(norms) # Largest side gets exactly args[0] points
        counts = np.rint(weights*args[0]).astype(int)
        if args[0] > 0:
            # A short direction (e.g. vacuum direction of a slab) would round to 0 points and
            # empty the whole mesh; keep at least one point along each direction.
            counts = np.maximum(counts, 1)
        nx, ny, nz = counts.tolist()
    else:
        for i,a in enumerate(args):
            if not isinstance(a, int_types):
                raise ValueError("get_kmesh expects integer at position {}!".format(i))
        nx,ny,nz = args

    low,high = np.array([[0,0,0],[1,1,1]]) + shift
    if cartesian:
        verts = get_bz(poscar_data.basis, primitive=False).vertices
        low, high = np.min(verts,axis=0), np.max(verts,axis=0)
        low = (low * 2 * np.pi / poscar_data.extra_info.scale).round(12) # Cartesian KPOINTS are in unit of 2pi/SCALE
        high = (high * 2 * np.pi / poscar_data.extra_info.scale).round(12)

    xs, ys, zs = [np.linspace(lo, hi, n, endpoint = True) for lo, hi, n in zip(low, high, (nx, ny, nz))]
    points = [[i,j,k] for k in zs for j in ys for i in xs] # x varies fastest, z slowest

    if not points:
        raise ValueError('No KPOINTS in BZ from given input. Try larger input!')

    points = np.array(points)
    points[np.abs(points) < 1e-10] = 0 # Clean floating noise

    if weight is None:
        weight = 1/len(points)

    out_str = '\n'.join("{0:>16.10f}{1:>16.10f}{2:>16.10f}{3:>12.6f}".format(x,y,z,weight) for x,y,z in points)
    N = len(points)
    if ibzkpt:
        if not os.path.isfile(ibzkpt):
            print(f"Warning: ibzkpt file {ibzkpt!r} not found, writing only the uniform mesh.")
        else:
            with open(ibzkpt,'r') as f:
                lines = f.readlines()

            if (not cartesian) and (lines[2].strip()[0] in 'cCkK'):
                raise ValueError("ibzkpt file is in cartesian coordinates, use get_kmesh(...,cartesian = True)!")

            n_ibz = int(lines[1].strip())
            # Only the n_ibz lines after the 3 header lines are k-points; anything below
            # (e.g. a Tetrahedra section) must not leak into KPOINTS. rstrip avoids a blank line.
            ibz_str = ''.join(lines[3:3+n_ibz]).rstrip()
            N += n_ibz
            out_str = "{}\n{}".format(ibz_str,out_str)
    mode = 'Cartesian' if cartesian else 'Reciprocal'
    top_str = "Generated uniform mesh using PivotPy, GRID-SHAPE = [{},{},{}]\n\t{}\n{}".format(nx,ny,nz,N,mode)
    out_str = "{}\n{}".format(top_str,out_str)
    if outfile is not None:
        with open(outfile,'w') as f:
            f.write(out_str)
    else:
        print(out_str)

# Cell
def _tan_inv(vy,vx):
    """
    - Returns full angle from x-axis counter clockwise.
    - **Parameters**
        - vy : Perpendicular componet of vector including sign.
        - vx : Base compoent of vector including sign.
    """
    angle = 0  # Place hodler to handle exceptions
    if vx == 0 and vy == 0:
        angle = 0
    elif vx == 0 and np.sign(vy) == -1:
        angle = 3*np.pi/2
    elif vx == 0 and np.sign(vy) == 1:
        angle = np.pi/2
    else:
        theta = abs(np.arctan(vy/vx))
        if np.sign(vx) == 1 and np.sign(vy) == 1:
            angle = theta
        if np.sign(vx) == -1 and np.sign(vy) == 1:
            angle = np.pi - theta
        if np.sign(vx) == -1 and np.sign(vy) == -1:
            angle = np.pi + theta
        if np.sign(vx) == 1 and np.sign(vy) == -1:
            angle = 2*np.pi - theta
        if np.sign(vx) == -1 and vy == 0:
            angle = np.pi
        if np.sign(vx) == 1 and vy == 0:
            angle = 2*np.pi
    return angle

def order(points,loop=True):
    """
    - Returns indices of counterclockwise ordered vertices of a plane in 3D.
    - Counterclockwise is as seen from the side the plane normal points to. The normal is oriented
      away from the origin (outward for BZ faces); for a plane containing the origin it is oriented towards (1,1,1).
    - The first returned index is always 0, the reference point.
    - **Parameters**
        - points: numpy array of shape (N,3) or List[List(len=3)] of coplanar points, N >= 3.
        - loop  : Default is True and appends start point at end to make a loop.
    - **Example**
        > pts = np.array([[1,0,3],[0,0,0],[0,1,2]])
        > order(pts)
        ```
        array([0, 1, 2, 0])
        ```
    """
    points = np.asarray(points, dtype=float)
    center = np.mean(points,axis=0)
    vectors = points - center # In-plane vectors relative to centroid.

    # Plane normal: direction of least spread of the centered points (works for any plane,
    # unlike using the centroid vector, which lies in the plane when the plane contains the origin).
    normal = np.linalg.svd(vectors)[2][-1]
    ref = center if abs(np.dot(normal, center)) > 1e-8 else np.ones(3)
    if np.dot(normal, ref) < 0:
        normal = -normal

    ex = vectors[0]/np.linalg.norm(vectors[0]) # Local x-axis through the reference point.
    ey = np.cross(normal, ex)                   # Local y-axis, in plane.
    angles = np.arctan2(vectors @ ey, vectors @ ex) % (2*np.pi)
    angles[0] = 0 # Reference point always first, avoids 0 vs 2π ambiguity from rounding.
    inds = np.argsort(angles, kind='stable')

    if loop: # Add first at end for completing loop.
        inds = np.append(inds, inds[0])
    return inds


def _out_bz_plane(test_point,plane):
    """
    - Returns True if test_point is between plane and origin. Could be used to sample BZ mesh in place of ConvexHull.
    - **Parameters**
        - test_points: 3D point.
        - plane      : List of at least three coplanar 3D points.
    """
    outside = True
    p_test = np.array(test_point)
    plane = np.unique(plane,axis=0) #Avoid looped shape.
    c = np.mean(plane,axis=0) #center
    _dot_ = np.dot(p_test-c,c)
    if _dot_ < -1e-5:
        outside = False
    return outside


def _rad_angle(v1,v2):
    """
    - Returns interier angle between two vectors.
    - **Parameters**
        - v1,v2 : Two vectors/points in 3D.
    """
    v1 = np.array(v1)
    v2 = np.array(v2)
    norm  = np.linalg.norm(v1)*np.linalg.norm(v2)
    dot_p = np.round(np.dot(v1,v2)/norm,12)
    angle = np.arccos(dot_p)
    return angle

from scipy.spatial.transform import Rotation
def rotation(angle_deg,axis_vec):
    """Get a scipy Rotation object at given `angle_deg` (degrees) around `axis_vec` (any non-zero length).
    Usage:
        rot = rotation(60,[0,0,1])
        rot.apply([1,1,1])
        [-0.3660254  1.3660254  1.] #give this
    """
    unit_axis = np.array(axis_vec)/np.linalg.norm(axis_vec) # Rotation vector = unit axis * angle in radians
    rotvec = np.deg2rad(angle_deg) * unit_axis
    return Rotation.from_rotvec(rotvec)

# Cell
def get_bz(path_pos = None,loop = True,digits=8,primitive=False):
    """
    - Return required information to construct first Brillouin zone in form of tuple (basis, normals, vertices, faces, specials).
    - **Parameters**
        - path_pos : POSCAR file path or list of 3 Real space vectors in 3D as list[list,list,list].
        - loop   : If True, joins the last vertex of a BZ plane to starting vertex in order to complete loop. Does not affect `specials`.
        - digits : int, rounding off decimal places, no effect on intermediate calculations, just for pretty final results.
        - primitive: Default is False and returns Wigner-Seitz cell, If True returns parallelepiped in rec_basis.

    - **Attributes**
        - basis   : get_bz().basis, reciprocal lattice basis (rows, without 2π).
        - normals : get_bz().normals, vectors through BZ face centers (doubled for the Wigner-Seitz cell so they cross the planes).
        - vertices: get_bz().vertices, all vertices of BZ, could be input into ConvexHull for constructing 3D BZ.
        - faces   : get_bz().faces, vertices arranged into faces, could be input to Poly3DCollection of matplotlib for creating BZ from faces' patches.
        - specials : get_bz().specials, Data with attributes `coords`,`kpoints` and `near` in one-to-one correspondence for high symmetry KPOINTS.
          `coords` are Cartesian, `kpoints` are fractional in `basis`. Index 0 is Γ, followed by face centers, selected edge
          midpoints and vertices. `near` has, for each vertex with z >= 0, `[0, *indices of special points nearer to it than Γ]`.
    """
    basis = _get_basis(path_pos).inverted # Reciprocal basis as rows
    b1, b2, b3 = basis
    # Get all vectors for BZ
    if primitive:
        b0 = np.array([0,0,0])
        bd = b1+b2+b3 #Diagonal point
        faces = np.array([
                 [b0, b1, b1+b2, b2],
                 [b0,b2,b2+b3,b3],
                 [b0,b3,b3+b1,b1],
                 [b1,b1+b2,b1+b3,bd],
                 [b2,b2+b1,b2+b3,bd],
                 [b3,b3+b1,b3+b2,bd]
                ])
    else:
        # Voronoi diagram of Γ (first point) and the lattice points i*b1+j*b2+k*b3 with i,j,k in {0,1,-1}.
        vectors = np.array([i*b1+j*b2+k*b3 for i,j,k in product([0,1,-1],repeat=3)])
        vor = Voronoi(vectors)
        # Ridges touching point 0 (Γ) are the BZ faces; vertex counts differ, so keep a list.
        faces = [np.array([vor.vertices[i] for i in ridge_verts])
                 for pair, ridge_verts in vor.ridge_dict.items() if 0 in pair]

    verts = np.unique([v for vs in faces for v in vs],axis=0)

    face_vectors = [np.mean(f,axis=0) for f in faces] # In primitive point at face center
    if not primitive:
        face_vectors = [2*f for f in face_vectors] # In regular, cross plane as well.

    # Order Faces based on given value of loop.
    faces = [face[order(face,loop=loop)] for face in faces]

    # High symmetry KPOINTS: face centers and edge midpoints.
    mid_faces = np.array([np.mean(np.unique(face,axis=0),axis=0) for face in faces])
    mid_edges = []
    for f in faces:
        ring = f[:-1] if loop else f # Unique vertices in cyclic order (loop repeats the first at end).
        # Pair each vertex with the next one, including the closing edge ring[-1] -> ring[0],
        # so the result is the same for loop=True and loop=False.
        for a, b in zip(ring, np.roll(ring, -1, axis=0)):
            # Only edges whose end vertices are equally far from Γ get a midpoint.
            if np.isclose(np.linalg.norm(a),np.linalg.norm(b)):
                mid_edges.append(np.mean([a,b],axis=0))
    if mid_edges:
        mid_edges = np.unique(mid_edges,axis=0) # because faces share edges
        mid_faces = np.concatenate([mid_faces,mid_edges])
    # Bring all high symmetry points together: Γ, face centers, edge midpoints, vertices.
    mid_all = np.concatenate([[[0,0,0]],mid_faces,verts]) # Coords
    mid_basis_all = np.array([np.linalg.solve(basis.T,v) for v in mid_all]) # Kpoints (fractional)

    # Round off results
    mid_all_p    = np.round(mid_all,digits) # Coordinates
    mid_basis_p  = np.round(mid_basis_all,digits) # Relative points
    basis        = np.round(basis,digits)
    face_vectors = np.round(face_vectors,digits)
    verts        = np.round(verts,digits)
    faces        = tuple([np.round(face,digits) for face in faces])

    # Order special points near each vertex for z >= 0.
    _arrs = []
    for v in verts[verts[:,2]>=0]: # Only upper hemisphere.
        _arr = []
        for i,c in enumerate(mid_all_p): # coordinates.
            _arr.append([i, np.linalg.norm(v-c)])
        _arr = np.array(_arr)
        _arr = _arr[_arr[:,1].argsort()][:,0].astype(int)
        upto = np.where(_arr == 0)[0][0] # Position of Γ in the distance ordering.
        _arrs.append([0,*_arr[:upto]])
    one2one  = {'coords': mid_all_p ,'kpoints': mid_basis_p,'near': _arrs}
    out_dict = {'basis':basis, 'normals':face_vectors, 'vertices':verts,
                'faces':faces,'specials':one2one}
    return serializer.dict2tuple('BZ',out_dict)


# Cell
def splot_bz(bz_data, ax = None, plane=None,color='blue',fill=True,vectors=True,v3=False,vname='b',colormap='plasma',light_from=(1,1,1),alpha=0.4):
    """
    - Plots matplotlib's static figure.
    - **Parameters**
        - bz_data    : Output of `get_bz`.
        - fill       : True by default, determines whether to fill surface of BZ or not.
        - color      : color to fill surface and stroke color.
        - vectors    : Plots basis vectors, default is True.
        - v3         : Plots 3rd vector as well. Only works in 2D and when `vectors=True`.
        - plane      : Default is None and plots 3D surface. Can take 'xy','yz','zx','xz','zy','yx' to plot a 2D projection.
        - ax         : Auto generated by default, 2D/3D axes, auto converts in 3D on demand as well.
        - vname      : Default is `b` for reciprocal space, can set `a` for plotting cell as after `get_bz(get_bz().basis)` you get real space lattice back if `primitive=True` both times.
        - colormap  : If None, single color is applied, only works in 3D and `fill=True`. Colormap is applied along z. Unknown names fall back to 'viridis'.
        - light_from: Point from where light is thrown on BZ planes, default is (1,1,1). Only works on plane in 3D.
        - alpha    : Opacity of filling in range [0,1]. Increase for clear viewpoint.
    - **Returns**
        - ax   : Matplotlib's 2D axes if `plane` is given.
        - ax3d : Matplotlib's 3D axes if `plane=None`.
    - **Raises**
        - ValueError: if `plane` is given but is not one of the two-letter planes listed above.

    > Tip: `splot_bz(get_bz(rec_basis,primitive=True),vname='a')` will plot cell in real space.
    """
    valid_planes = 'xyzxzyx' # Every 2-letter substring is a valid plane: xy,yz,zx,xz,zy,yx
    # Validate before creating any axes; a plain substring test would accept 'x' or 'xyz'.
    if plane and not (isinstance(plane,str) and len(plane) == 2 and plane in valid_planes):
        raise ValueError(f"`plane` expects one of 'xy','yz','zx','xz','zy','yx' or None, got {plane!r}")

    label = r"$k_{}$" if vname=='b' else "{}"
    if not ax: #For both 3D and 2D, initialize 2D axis.
        ax = sp.get_axes(figsize=(3.4,3.4)) #For better display

    _label = r'\vec{}'.format(vname) # For both

    if plane: #Project 2D
        faces = bz_data.faces
        ind = valid_planes.index(plane)
        arr = [0,1,2,0,2,1,0] # Axis index of each letter of valid_planes
        i, j = arr[ind], arr[ind+1]
        for f in faces:
            ax.plot(f[:,i],f[:,j],color=(color),lw=0.7)

        if vectors:
            if v3:
                s_basis = bz_data.basis
                ijk = [0,1,2]
            else:
                s_basis = bz_data.basis[[i,j]]# Only two.
                ijk = [i,j]

            for k,y in zip(ijk,s_basis):
                l = "\n" + r" ${}_{}$".format(_label,k+1)
                ax.text(0.8*y[i],0.8*y[j], l, va='center',ha='left')
                ax.scatter([y[i]],[y[j]],color='w',s=0.0005) # Must be to scale below arrow.

            s_zero = [0]*len(s_basis) # either 3 or 2.
            ax.quiver(s_zero,s_zero,*s_basis.T[[i,j]],lw=0.9,color='k',angles='xy', scale_units='xy', scale=1)

        ax.set_xlabel(label.format(valid_planes[i]))
        ax.set_ylabel(label.format(valid_planes[j]))
        ax.set_aspect(1) # Must for 2D axes to show actual lengths of BZ
        return ax

    # Plot 3D
    if ax and ax.name == "3d":
        ax3d = ax
    else:
        pos = ax.get_position()
        fig = ax.get_figure()
        ax.remove()
        ax3d = fig.add_axes(pos,projection='3d',azim=45,elev=30,proj_type='ortho')

    if fill:
        if colormap:
            colormap = colormap if colormap in plt.colormaps() else 'viridis'
            cz = np.array([np.mean(np.unique(f,axis=0),axis=0)[2] for f in bz_data.faces])
            span = np.ptp(cz)
            levels = (cz - np.min(cz))/span if span else np.zeros_like(cz) # along Z; guard flat data
            colors = plt.get_cmap(colormap)(levels) # plt.cm.get_cmap was removed in Matplotlib 3.9
        else:
            colors = np.array([[*mplc.to_rgb(color)] for f in bz_data.faces]) # Single color.
        if light_from:
            intensity = bz_data.normals.dot(light_from) #Plane facing light
            span = np.ptp(intensity)
            if span:
                intensity = ((intensity - np.min(intensity) + 0.2)/span).clip(0,1)
            else:
                intensity = np.ones_like(intensity) # All planes equally lit
            colors = np.array([i*c[:3] for i, c in zip(intensity,colors)])

        poly = Poly3DCollection(bz_data.faces,edgecolors=[color,],facecolors=colors, alpha=alpha)
        ax3d.add_collection3d(poly)
        ax3d.autoscale_view()
    else:
        for f in bz_data.faces:
            ax3d.plot3D(f[:,0],f[:,1],f[:,2],color=(color),lw=0.7)

    if vectors:
        for k,v in enumerate(0.35*bz_data.basis):
            ax3d.text(*v,r"${}_{}$".format(_label,k+1),va='center',ha='center')

        XYZ,UVW = [[0,0,0],[0,0,0],[0,0,0]], 0.3*bz_data.basis.T
        fancy_quiver3d(*XYZ,*UVW,C='k',L=0.7,ax=ax3d,arrowstyle="-|>",mutation_scale=7)

    l_ = np.min(bz_data.vertices,axis=0)
    h_ = np.max(bz_data.vertices,axis=0)
    ax3d.set_xlim([l_[0],h_[0]])
    ax3d.set_ylim([l_[1],h_[1]])
    ax3d.set_zlim([l_[2],h_[2]])

    # Set aspect to same as data.
    ax3d.set_box_aspect(np.ptp(bz_data.vertices,axis=0))

    ax3d.set_xlabel(label.format('x'))
    ax3d.set_ylabel(label.format('y'))
    ax3d.set_zlabel(label.format('z'))
    return ax3d

# Cell
def iplot_bz(bz_data,fill = True,color = 'rgba(168,204,216,0.4)',background = 'rgb(255,255,255)',vname = 'b', special_kpoints = True, alpha=0.4,ortho3d=True,fig=None):
    """
    - Plots interactive figure showing axes,BZ surface, special points and basis, each of which could be hidden or shown.
    - **Parameters**
        - bz_data    : Output of `get_bz`. Reads `basis`, `faces`, `vertices` and, for special points, `specials.coords`/`specials.kpoints`.
        - fill       : True by default, determines whether to fill surface of BZ (convex hull of `vertices`) or not.
        - color      : Color of face outlines and of the filled surface, 'rgba(168,204,216,0.4)' by default.
        - background : Paper and plot background color, default is 'rgb(255,255,255)'.
        - vname      : Default is `b` for reciprocal space, can set `a` for plotting cell as after `get_bz(get_bz().basis)` you get real space lattice back if `primitive=True` both times. With `a`, axes are labeled x,y,z and special points are not drawn.
        - special_kpoints : True by default, determines whether to plot special points or not (only used when `vname='b'`).
        - alpha      : Opacity of the filled BZ surface (only used when `fill=True`).
        - ortho3d    : Default is True and uses orthographic projection; if False plotly's default (perspective) camera projection is kept.
        - fig        : (Optional) Plotly's `go.Figure`. If you want to plot on another plotly's figure, provide that; traces are added to it.
    - **Returns**
        - fig   : plotly.graph_object's Figure instance.

    > Tip: `iplot_bz(get_bz(rec_basis,primitive=True),vname='a')` will plot cell in real space.
    """
    if not fig:
        fig = go.Figure()
    # Name fixing: labels and legend names for reciprocal ('b') space, replaced below for real ('a') space.
    axes_text = ["<b>k</b><sub>x</sub>","","<b>k</b><sub>y</sub>","","<b>k</b><sub>z</sub>"]
    s_name = 'BZ'
    a_name = 'Axes'
    if vname == 'a':
        axes_text = ["<b>x</b>","","<b>y</b>","","<b>z</b>"] # Real space
        s_name = 'Lattice'
        a_name = 'RealAxes'

    # Axes
    # Axis length heuristic: half the mean of all basis components.
    # NOTE(review): this is small or negative if many basis components are negative; confirm intended.
    _len = 0.5*np.mean(bz_data.basis)
    # One polyline x -> origin -> y -> origin -> z; text labels sit on the axis tips.
    fig.add_trace(go.Scatter3d(x=[_len,0,0,0,0],y=[0,0,_len,0,0],z=[0,0,0,0,_len],
        mode='lines+text',
        text= axes_text,
        line_color='green', legendgroup=a_name,name=a_name))
    # Cones near the axis tips act as arrow heads.
    fig.add_trace(go.Cone(x=[0.95*_len,0,0],y=[0,0.95*_len,0],z=[0,0,0.95*_len],
        u=[0.2*_len,0,0],v=[0,0.2*_len,0],w=[0,0,0.2*_len],showscale=False,
        colorscale='Greens',legendgroup=a_name,name=a_name))
    # Basis: a line from origin plus a cone arrow head per vector, grouped in the legend by vector name.
    for i,b in enumerate(bz_data.basis):
        fig.add_trace(go.Scatter3d(x=[0,b[0]], y=[0,b[1]],z=[0,b[2]],
            mode='lines+text',legendgroup="{}<sub>{}</sub>".format(vname,i+1), line_color='red',
            name="<b>{}</b><sub>{}</sub>".format(vname,i+1),text=["","<b>{}</b><sub>{}</sub>".format(vname,i+1)]))
        fig.add_trace(go.Cone(x=[0.95*b[0]],y=[0.95*b[1]],z=[0.95*b[2]],
            u=[0.2*b[0]],v=[0.2*b[1]],w=[0.2*b  [2]],showscale=False,colorscale='Reds',
            legendgroup="{}<sub>{}</sub>".format(vname,i+1),name="<b>{}</b><sub>{}</sub>".format(vname,i+1)))

    # Faces: one outline trace per face, all in one legend group.
    legend = True
    for pts in bz_data.faces:
        fig.add_trace(go.Scatter3d(x=pts[:,0], y=pts[:,1],z=pts[:,2],
            mode='lines',line_color=color, legendgroup=s_name,name=s_name,
            showlegend=legend))
        legend = False # Only first legend to show for all

    if fill:
        # Hull vertices of the BZ; per Plotly docs, Mesh3d with alphahull=0 meshes their convex hull.
        xc = bz_data.vertices[ConvexHull(bz_data.vertices).vertices]
        fig.add_trace(go.Mesh3d(x=xc[:, 0], y=xc[:, 1], z=xc[:, 2],
                        color=color,
                        opacity=alpha,
                        alphahull=0,
                        lighting=dict(diffuse=0.5),
                        legendgroup=s_name,name=s_name))

    # Special Points only if in reciprocal space.
    if vname == 'b' and special_kpoints:
        # Hover text shows fractional kpoint, distance d from origin and index in `specials`.
        texts,values =[],[]
        norms = np.round(np.linalg.norm(bz_data.specials.coords,axis=1),5)
        sps = bz_data.specials
        for key,value, (i,norm) in zip(sps.kpoints, sps.coords, enumerate(norms)):
            texts.append("P{}</br>d = {}</br> Index = {}".format(key,norm,i))
            values.append([[*value,norm]])

        # values rows: [x, y, z, distance].
        values = np.array(values).reshape((-1,4))
        # Color by distance: scale distances to 0..255, then give each distinct value a color on a
        # red (farthest) -> blue ramp. The smallest-distance group is skipped by `_unique[:-1]` and keeps
        # the placeholder 0; index 0 is then overwritten with gold.
        norm_max = np.max(values[:,3])
        c_vals = np.array([int(v*255/norm_max) for v in values[:,3]])
        colors = [0 for i in c_vals]
        _unique = np.unique(np.sort(c_vals))[::-1]
        _lnp = np.linspace(0,255,len(_unique)-1)
        _u_colors = ["rgb({},0,{})".format(r,b) for b,r in zip(_lnp,_lnp[::-1])]
        for _un,_uc in zip(_unique[:-1],_u_colors):
            _index = np.where(c_vals == _un)[0]
            for _ind in _index:
                colors[_ind]=_uc

        colors[0]= "rgb(255,215,0)" # Gold color at Gamma!.
        fig.add_trace(go.Scatter3d(x=values[:,0], y=values[:,1],z=values[:,2],
                hovertext=texts,name="HSK",marker=dict(color=colors,size=4),mode='markers'))
    # Layout: optional orthographic camera, hidden axes/backgrounds, aspectmode='data' keeps true proportions.
    proj = dict(projection=dict(type = "orthographic")) if ortho3d else {}
    camera = dict(center=dict(x=0.1, y=0.1, z=0.1),**proj)
    fig.update_layout(scene_camera=camera,paper_bgcolor=background, plot_bgcolor=background,
        font_family="Times New Roman",font_size= 14,
        scene = dict(aspectmode='data',xaxis = dict(showbackground=False,visible=False),
                        yaxis = dict(showbackground=False,visible=False),
                        zaxis = dict(showbackground=False,visible=False)),
                        margin=dict(r=10, l=10,b=10, t=30))
    return fig

# Cell
def to_R3(basis,points):
    """Transforms coordinates of points (relative to non-orthogonal basis) into orthogonal (Cartesian) space.
    - **Parameters**
        - basis : Non-orthogonal basis of real or reciprocal space, vectors as rows.
        - points: 3D points relative to basis, such as KPOINTS and Lattice Points.
    - **Returns**
        - Cartesian coordinates: for each point [n1, n2, n3], n1*b1 + n2*b2 + n3*b3 (i.e. `points . basis`).

    **Note**: Do not use this function if points are Cartesian or provide identity basis.
    """
    return np.dot(np.array(points), np.array(basis))

def to_basis(basis,coords):
    """Transforms Cartesian coordinates of points into coordinates relative to `basis` (inverse of `to_R3`).
    - **Parameters**
        - basis : Non-orthogonal basis of real or reciprocal space, vectors as rows (3x3).
        - coords: 3D points relative to cartesian axes, such as KPOINTS and Lattice Points.
          A single point of shape (3,) or an array/list of shape (..., 3).
    - **Returns**
        - Array with the same shape as `coords`, so that `to_R3(basis, to_basis(basis, coords))` recovers `coords`.
    """
    # coords = points @ basis  =>  points = coords @ inv(basis); asarray lets lists work too.
    return np.asarray(coords) @ np.linalg.inv(np.asarray(basis))

# Cell
def kpoints2bz(bz_data,kpoints,sys_info = None, primitive=False, shift = 0):
    """Brings KPOINTS inside BZ. Applies `to_R3` only if `primitive=True`.
    - **Parameters**
        - bz_data  : Output of get_bz(), make sure use same value of `primitive` there and here.
        - kpoints  : List or array of KPOINTS, shape (N,3), relative to `bz_data.basis`, to transform into BZ or R3.
        - sys_info : If given and `sys_info.space_info.cartesian_kpoints` is True, returns `to_R3(bz_data.basis,kpoints)` directly. Useful if kpoints are cartesian and you need to scale those.
        - primitive: Default is False and brings kpoints into regular BZ. If True, returns `to_R3()`.
        - shift    : This value is added to kpoints before any other operation, single number or list of 3 numbers for each direction.
    - **Returns**
        - Array (N,3) of Cartesian coordinates. For each kpoint, the first of its translations by {0,1,-1}^3
          (untranslated first) that lies inside the BZ within 1e-8 is used; if none does, the translation
          closest to Γ is used. May contain duplicates, apply `np.unique(out, axis=0)` if needed.

    **Note**: If kpoints are Cartesian, provide sys_info, otherwise it will go wrong.
    """
    kpoints = np.array(kpoints) + shift
    if sys_info is not None and sys_info.space_info.cartesian_kpoints:
        return to_R3(bz_data.basis,kpoints) # Already relative to basis of BZ

    if primitive:
        return to_R3(bz_data.basis,kpoints)

    # A Cartesian point p is inside the BZ if (p - c).c <= 0 for every face center c.
    cent_planes = np.array([np.mean(np.unique(face,axis=0),axis=0) for face in bz_data.faces])

    def _inside(coord):
        return np.max(np.sum((coord - cent_planes)*cent_planes, axis=1)) <= 1e-8

    # Lattice translations in a fixed order; (0,0,0) first so points already inside stay put.
    translations = np.array(list(product([0,1,-1],repeat=3)))

    out_coords = np.empty(np.shape(kpoints)) # To store back
    for i,p in enumerate(kpoints):
        candidates = to_R3(bz_data.basis, p + translations) # First translate, then make coords
        for cand in candidates:
            if _inside(cand):
                out_coords[i] = cand
                break
        else:
            # No translation landed inside within tolerance: use the candidate closest to Γ
            # instead of leaving uninitialized memory from np.empty.
            out_coords[i] = candidates[np.argmin(np.linalg.norm(candidates,axis=1))]

    return out_coords

# Cell
def fix_sites(poscar_data,tol=1e-2,eqv_sites=False,translate=None):
    """Add equivalent sites to make a full data shape of lattice. Returns same data after fixing.
    - **Parameters**
        - poscar_data: Output of `export_poscar` or `export_vasprun().poscar`.
        - tol   : Tolerance value. Default is 0.01.
        - eqv_sites: If True, add sites on edges and faces. If False, just fix coordinates, i.e. `pos > 1 - tol -> pos - 1`, useful for merging poscars to make slabs.
        - translate: A number(+/-) or list of three numbers to translate in x,y,z directions. Only the fractional
          part is used, and translated positions are wrapped back into [0,1) before the `tol` fix.
    - **Returns**
        - A new `serializer.PoscarData` built from `poscar_data.to_dict()` with updated `positions`, comment and,
          if `eqv_sites`, `unique` index ranges.
    - **Raises**
        - ValueError: if `translate` is neither a number nor three numbers.
    """
    pos = poscar_data.positions.copy()
    out_dict = poscar_data.to_dict() # For output

    if translate is not None:
        offset = np.asarray(translate, dtype=float)
        if offset.shape not in [(), (3,)]:
            raise ValueError(f"`translate` expects a number or three numbers, got {translate!r}")
        offset = offset - np.trunc(offset) # Only translate in 0 - 1
        if np.any(offset != 0):
            # Wrap into [0,1) so negative translations do not push sites out of the cell.
            pos = (pos + offset) % 1

    # Fix coordinates of sites distributed on edges and faces
    pos -= (pos > (1 - tol)).astype(int) # Move towards origin for common fixing like in joining POSCARs
    out_dict['positions'] = pos
    out_dict['extra_info']['comment'] = 'Modified by Pivotpy'

    # Add equivalent sites on edges and faces if given, handle each species separately
    if eqv_sites:
        # Copies shifted by 1 along these axis combinations: x, y, z, xy, yz, zx, xyz.
        axes_list = ([0],[1],[2],[0,1],[1,2],[0,2],[0,1,2])
        new_dict, start = {}, 0
        for k,v in out_dict['unique'].items():
            vpos = pos[v]
            near_zero = (vpos + 1) < (tol + 1) # Per-axis flag: coordinate within tol of 0
            blocks = [vpos]
            for axes in axes_list:
                unit = np.zeros(3)
                unit[axes] = 1
                blocks.append(vpos[near_zero[:,axes].all(axis=1)] + unit)
            new_pos = np.vstack(blocks)
            new_dict[k] = {'pos': new_pos, 'range': range(start,start+len(new_pos))}
            start += len(new_pos)

        out_dict['positions'] = np.vstack([new_dict[k]['pos'] for k in new_dict.keys()])
        out_dict['unique'] = {k:new_dict[k]['range'] for k in new_dict.keys()}

    return serializer.PoscarData(out_dict)

def translate_poscar(poscar_data, offset):
    """Translate sites of a POSCAR, usually by a fraction like 1/2, 1/4 etc.
    Same as `fix_sites(poscar_data, translate=offset)` without adding equivalent sites.
    - **Parameters**
        - poscar_data: Output of `export_poscar` or `export_vasprun().poscar`.
        - offset: A number(+/-) or list of three numbers to translate in x,y,z directions.
    """
    options = dict(eqv_sites=False, translate=offset)
    return fix_sites(poscar_data, **options)

def get_pairs(poscar_data, positions, r, tol=1e-3):
    """Returns a tuple of Lattice (coords,pairs), so coords[pairs] gives nearest site bonds.
    - **Parameters**
        - poscar_data: Output of `export_poscar` or `export_vasprun().poscar`. Only `extra_info.cartesian` and `basis` are read.
        - positions: Array(N,3) of fractional positions of lattice sites. If `poscar_data.extra_info.cartesian`
          is True, positions are used as Cartesian (identity basis).
        - r        : Maximum Cartesian distance between the pairs in units of Angstrom e.g. 1.2 -> 1.2E-10.
        - tol      : Passed as `eps` to `KDTree.query_pairs` (approximate search), so pairs slightly farther than `r` may be included. Default is 10^-3.
    - **Returns**
        - Lattice(coords, pairs): `coords` is (N,3) Cartesian; `pairs` is a sorted integer array of shape (M,2)
          with i < j in each row, and shape (0,2) when no pairs are found.
    """
    basis = np.identity(3) if poscar_data.extra_info.cartesian else poscar_data.basis
    coords = to_R3(basis,positions)
    tree = KDTree(coords)
    # query_pairs returns an unordered set; sort for deterministic output and keep an integer
    # (M,2) shape even when empty so coords[pairs] always works.
    inds = np.array(sorted(tree.query_pairs(r,eps=tol)), dtype=int).reshape(-1,2)
    return serializer.dict2tuple('Lattice',{'coords':coords,'pairs':inds})

def _get_bond_length(poscar_data,given=None,tol=1e-3):
    "tol is add to calculated bond length in order to fix small differences, paramater `given` in range [0,1] which is scaled to V^(1/3)."
    if given != None:
        return given*poscar_data.volume**(1/3) + tol
    else:
        basis = np.identity(3) if poscar_data.extra_info.cartesian else poscar_data.basis
        _coords = to_R3(basis,poscar_data.positions)
        _arr = sorted(np.linalg.norm(_coords[1:] - _coords[0],axis=1)) # Sort in ascending. returns list
        return np.mean(_arr[:2]) + tol if _arr else 1 #Between nearest and second nearest.

# Cell
def iplot_lat(poscar_data,sizes=10,colors = None,
              bond_length=None,tol=1e-2,bond_tol=1e-3,eqv_sites=True,
              translate = None,
              line_width=4,edge_color = 'black',
              fill=False,alpha=0.4, ortho3d=True,fig=None):
    """Interactive plot of lattice.
    - **Main Parameters**
        - poscar_data: Output of export_poscar or export_vasprun().poscar.
        - sizes      : Size of sites. Either one int/float or list equal to type of ions.
        - colors     : Sequence of colors for each type. Automatically generated if not provided.
        - bond_length: Length of bond in fractional unit [0,1]. It is scaled to V^1/3 and auto calculated if not provided.
        - fig        : Plotly figure to draw into. A new `go.Figure` is created if None.
    Other parameters just mean what they seem to be.
    - **Returns**
        - The plotly figure with bonds, sites and cell edges added.
    """
    poscar_data = fix_sites(poscar_data,tol=tol,eqv_sites=eqv_sites,translate=translate)
    bond_length = _get_bond_length(poscar_data,given=bond_length,tol=tol)
    # Bond tolerance should be smaller than cell tolerance.
    coords, pairs = get_pairs(poscar_data,
                        positions = poscar_data.positions,
                        r=bond_length,tol = bond_tol)
    if fig is None:
        fig = go.Figure()

    uelems = poscar_data.unique.to_dict()
    if not isinstance(sizes,(list,tuple,np.ndarray)):
        sizes = [sizes for _ in uelems]

    # `is not None` so that numpy arrays of colors work (their truth value is ambiguous).
    if colors is not None and len(colors) != len(uelems):
        print('Warning: Number of colors does not match number of elements. Using default colors.')

    if colors is None or len(colors) != len(uelems):
        colors = [_atom_colors[elem] for elem in uelems]
        colors = ['rgb({},{},{})'.format(*[int(_c*255) for _c in c]) for c in colors]

    # One color per site, in the same order as the sites listed in `unique`.
    site_colors = np.array([colors[i] for i, vs in enumerate(uelems.values()) for _ in vs])
    h_text = np.array(poscar_data.labels)

    if len(pairs):
        ends = coords[pairs]         # (P, 2, 3): both end points of every bond
        mids = ends.mean(axis=1)     # (P, 3): bond midpoints
        # Split every bond at its midpoint so each half is drawn in the color of its own site.
        halves = np.stack([np.stack([ends[:, 0], mids], axis=1),
                           np.stack([mids, ends[:, 1]], axis=1)], axis=1).reshape((-1, 2, 3))
        pair_colors = site_colors[pairs]
        half_colors = pair_colors.reshape((-1, *pair_colors.shape[2:]))

        for i, (seg, c) in enumerate(zip(halves, half_colors)):
            fig.add_trace(go.Scatter3d(
                x = seg[:,0],
                y = seg[:,1],
                z = seg[:,2],
                mode='lines',line_color = c,
                legendgroup='Bonds',showlegend=(i == 0), # Single legend entry for all bonds.
                name='Bonds',line_width=line_width))

    for (k,v),c,s in zip(uelems.items(),colors,sizes):
        fig.add_trace(go.Scatter3d(
            x = coords[v][:,0],
            y = coords[v][:,1],
            z = coords[v][:,2],
            mode='markers',marker_color = c,
            hovertext = h_text[v],
            line_color='rgba(1,1,1,0)',line_width=0.001,
            marker_size = s,opacity=1,name=k))

    bz = get_bz(path_pos= poscar_data.rec_basis, primitive=True)
    _ = iplot_bz(bz,fig=fig,vname='a',color=edge_color,
                fill=fill,alpha=alpha,ortho3d=ortho3d)
    return fig

# Cell
def splot_lat(poscar_data,plane = None, sizes=50,colors=None,colormap=None,
              bond_length=None,tol=1e-2,bond_tol=1e-3,eqv_sites=True,
              translate = None,
              line_width=1,edge_color=((1,0.5,0,0.4)),
              vectors=True,v3=False,
              light_from=(1,1,1),
              fill=False,alpha=0.4,ax=None,alpha_points = 0.7):
    """Static plot of lattice.
    - **Main Parameters**
        - poscar_data: Output of export_poscar or export_vasprun().poscar.
        - plane      : Plane to plot. One of 'xy','yz','zx','xz','zy','yx' or None for 3D plot.
        - sizes      : Size of sites. Either one int/float or list equal to type of ions.
        - bond_length: Length of bond in fractional unit [0,1]. It is scaled to V^1/3 and auto calculated if not provided.
        - colors: Sequence of colors for each ion type. If None, automatically generated.
        - colormap: This is passed to splot_bz.
        - alpha_points: Opacity of points and bonds.
    Other parameters just mean what they seem to be.
    - **Raises**
        - ValueError: if `plane` is not one of the accepted values.

    > Tip: Use `plt.style.use('ggplot')` for better 3D perception.
    """
    # Each valid plane maps to the (horizontal, vertical) coordinate indices. An exact lookup replaces
    # the old substring test, which also accepted strings such as 'x', 'xyz' or 'zxz'.
    plane_axes = {'xy': (0, 1), 'yz': (1, 2), 'zx': (2, 0), 'xz': (0, 2), 'zy': (2, 1), 'yx': (1, 0)}
    if plane and plane not in plane_axes:
        raise ValueError(f"plane expects one of {list(plane_axes)} or None, got {plane!r}.")
    if plane:
        ix, iy = plane_axes[plane]

    poscar_data = fix_sites(poscar_data,tol=tol,eqv_sites=eqv_sites,translate=translate)
    bond_length = _get_bond_length(poscar_data,given=bond_length,tol=tol)
    # Bond tolerance should be smaller than cell tolerance.
    coords, pairs = get_pairs(poscar_data,
                        positions = poscar_data.positions,
                        r=bond_length,tol = bond_tol)
    bz = get_bz( poscar_data.rec_basis, primitive=True)
    ax = splot_bz(bz,ax=ax,vname='a',
                color=edge_color,colormap=colormap,
                fill=fill,alpha=alpha,plane=plane,v3=v3,
                vectors=vectors,light_from=light_from)

    uelems = poscar_data.unique.to_dict()
    if not isinstance(sizes,(list,tuple, np.ndarray)):
        sizes = [sizes for _ in uelems]

    # `is not None` so that numpy arrays of colors work (their truth value is ambiguous).
    if colors is not None and len(colors) != len(uelems):
        print('Warning: Number of colors does not match number of elements. Using default colors.')

    if colors is None or len(colors) != len(uelems):
        colors = [_atom_colors[elem] for elem in uelems]

    # Before doing other stuff, create something for legend.
    for (k,v),c,s in zip(uelems.items(),colors,sizes):
        ax.scatter([],[],s=s,color=c,label=k) # Works both for 3D and 2D.

    # Now expand colors and sizes to one entry per site.
    colors = np.array([colors[i] for i, vs in enumerate(uelems.values()) for _ in vs])
    sizes = np.array([sizes[i] for i, vs in enumerate(uelems.values()) for _ in vs])

    if len(pairs):
        ends = coords[pairs]         # (P, 2, 3): both end points of every bond
        mids = ends.mean(axis=1)     # (P, 3): bond midpoints
        # Split every bond at its midpoint so each half is drawn in the color of its own site.
        halves = np.stack([np.stack([ends[:, 0], mids], axis=1),
                           np.stack([mids, ends[:, 1]], axis=1)], axis=1).reshape((-1, 2, 3))
        pair_colors = colors[pairs]
        half_colors = pair_colors.reshape((-1, *pair_colors.shape[2:]))

        for seg, c in zip(halves, half_colors):
            if plane:
                ax.plot(seg[:,ix], seg[:,iy], c=c, lw=line_width, alpha=alpha_points)
            else:
                ax.plot(*seg.T, c=c, lw=line_width, alpha=alpha_points)

    if not plane:
        ax.scatter(coords[:,0],coords[:,1],coords[:,2],c = colors ,s =sizes,depthshade=False,alpha=alpha_points)
    else:
        iz = 3 - ix - iy  # Index of the remaining axis, normal to the plane.
        zorder = coords[:,iz].argsort()
        if plane in ('yx', 'xz', 'zy'): # Left handed: reverse drawing order along the normal axis.
            zorder = zorder[::-1]
        ax.scatter(coords[zorder][:,ix],coords[zorder][:,iy],c = colors[zorder] ,s =sizes[zorder],zorder=3, alpha= alpha_points)

    ax.set_axis_off()
    sp.add_legend(ax)
    return ax

# Cell
def join_poscars(poscar1,poscar2,direction='z',tol=1e-2, system = None):
    """Joins two POSCARs in a given direction. In-plane lattice parameters are kept from `poscar1` and the basis vector of `poscar2` parallel to `direction` is rescaled so that its volume is kept the same.
    - **Parameters**
        - poscar1, poscar2:  Base and secondary POSCARs respectively. Output of `export_poscar` or similar object from other functions.
        - direction: The joining direction. It is general and can join in any direction along basis. Expect one of ['a','b','c','x','y','z'].
        - tol: Default is 0.01. It is used to bring sites near 1 to near zero in order to complete sites in plane. Vasp relaxation could move a point, say at 0.00100 to 0.99800 which is not useful while merging sites.
        - system: If system is given, it is written on top of file. Otherwise, it is inferred from atomic species.
    - **Raises**
        - ValueError: if `direction` is not one of the accepted values.
    """
    # Index k of the basis vector along which the two cells are stacked.
    axis_names = (('x', 'a'), ('y', 'b'), ('z', 'c'))
    k = next((n for n, names in enumerate(axis_names) if direction in names), None)
    if k is None:
        raise ValueError("direction expects one of ['a','b','c','x','y','z']")
    i, j = [n for n in range(3) if n != k] # In-plane basis vectors, in ascending order.

    _poscar1 = fix_sites(poscar1,tol=tol,eqv_sites=False)
    _poscar2 = fix_sites(poscar2,tol=tol,eqv_sites=False)
    pos1 = _poscar1.positions.copy()
    pos2 = _poscar2.positions.copy()

    norms1 = np.linalg.norm(_poscar1.basis,axis=1)
    norms2 = np.linalg.norm(_poscar2.basis,axis=1)
    basis = _poscar1.basis.copy() # Must be copied, otherwise change outside.

    # Processing in orthogonal space since a.(b x c) = abc sin(theta)cos(phi), and theta and phi are same for both.
    # Conservation of volume: poscar2's length along k is stretched by the ratio of in-plane areas.
    len1 = norms1[k]
    len2 = (norms2[i]*norms2[j])/(norms1[i]*norms1[j])*norms2[k]
    net = len1 + len2
    s1, s2 = len1/net, len2/net
    pos1[:,k] = s1*pos1[:,k]
    pos2[:,k] = s2*pos2[:,k] + s1 # poscar2 sites are placed after poscar1 along k.
    basis[k] = net*basis[k]/np.linalg.norm(basis[k]) # Update joining vector.

    scale = np.linalg.norm(basis[0])
    u1 = _poscar1.unique.to_dict()
    u2 = _poscar2.unique.to_dict()
    u_all = ({**u1,**u2}).keys() # Union of unique elements to keep track of order.

    pos_all, counts = [], []
    for u in u_all:
        n_u = 0
        if u in u1:
            n_u += len(u1[u])
            pos_all.extend(pos1[u1[u]])
        if u in u2:
            n_u += len(u2[u])
            pos_all.extend(pos2[u2[u]])
        counts.append(n_u)

    bounds = np.cumsum([0, *counts])
    uelems = {_u: range(bounds[n], bounds[n+1]) for n, _u in enumerate(u_all)}
    system_name = system or ''.join(uelems.keys()) # Not `sys`: would shadow the imported module.
    iscartesian = poscar1.extra_info.cartesian or poscar2.extra_info.cartesian
    extra_info = {'cartesian':iscartesian, 'scale': scale, 'comment': 'Modified by Pivotpy'}
    out_dict = {'SYSTEM':system_name,'basis':basis,'extra_info':extra_info,'positions':np.array(pos_all),'unique':uelems}
    return serializer.PoscarData(out_dict)


# Cell
def repeat_poscar(poscar_data, n, direction):
    """Repeat a given POSCAR `n` times along `direction`.
    - **Parameters**
        - poscar_data: `poscar` data object.
        - n: Number of repetitions, an integer >= 1. `n = 1` returns `poscar_data` unchanged.
        - direction: Direction of repetition. Can be 'x', 'y' or 'z' (passed to `join_poscars`).
    - **Raises**
        - ValueError: if `n` is not an integer >= 1.
    """
    # The old check used `and`, which let 2.5, 0 or negative values through. numpy ints are accepted
    # because `scale_poscar` passes values from `astype(int)`; n = 1 is used by `transform_poscar`.
    if not isinstance(n, (int, np.integer)) or n < 1:
        raise ValueError("n must be an integer >= 1.")
    given_poscar = poscar_data
    for _ in range(1, n):
        poscar_data = join_poscars(given_poscar, poscar_data, direction=direction)
    return poscar_data

def scale_poscar(poscar_data,scale = (1,1,1),tol=1e-2):
    """Create larger/smaller cell from a given POSCAR. Can be used to repeat a POSCAR with integer scale values.
    - **Parameters**
        - poscar_data: `poscar` data object.
        - scale: Tuple of three positive values along (a,b,c) vectors. int or float values. If number of sites are not as expected in output, tweak `tol` instead of `scale`. You can put a minus sign with `tol` to get more sites and plus sign to reduce sites.
        - tol: It is used such that site positions are below `1 - tol`, as 1 belongs to next cell, not previous one.
    - **Raises**
        - ValueError: if any value in `scale` is not positive.
    **Tip:** scale = (2,2,2) enlarges a cell and next operation of (1/2,1/2,1/2) should bring original cell back.
    **Caveat:** A POSCAR scaled with non-integer values should only be used for visualization purposes, not for any other operation such as making supercells or joining POSCARs.
    """
    # Zero/negative values would give zero/negative repetitions and a division by zero below.
    if np.any(np.asarray(scale, dtype=float) <= 0):
        raise ValueError(f'All values in scale must be positive, got {tuple(scale)}.')

    ii, jj, kk = np.ceil(scale).astype(int) # Need int for joining.

    if tuple(scale) == (1,1,1): # No need to scale.
        return poscar_data

    for n, direction in zip((ii, jj, kk), 'xyz'):
        if n >= 2:
            poscar_data = repeat_poscar(poscar_data, n, direction=direction)

    if all(s == int(s) for s in scale):
        return poscar_data # No need to process further in case of integer scaling.

    new_poscar = poscar_data.to_dict() # Update in it

    # Clip fractions, then rescale positions and basis so that clipping at 1 applies the scale.
    fractions = np.array([scale[0]/ii, scale[1]/jj, scale[2]/kk])
    pos = poscar_data.positions/fractions
    basis = fractions[:, None]*poscar_data.basis # Each basis vector (row) scaled by its fraction.

    new_poscar['basis'] = basis
    new_poscar['extra_info']['scale'] = np.linalg.norm(basis[0])
    new_poscar['extra_info']['comment'] = 'Modified by Pivotpy'

    uelems, positions, shift = {}, [], 0
    for key, value in poscar_data.unique.to_dict().items():
        s_p = pos[value] # Get positions of key
        # Keep sites below 1 - tol, because if we have 0-2 then 1 belongs to next cell not original.
        s_p = s_p[(s_p < 1 - tol).all(axis=1)]

        if len(s_p) == 0:
            raise Exception(f'No sites found for {key!r}, cannot scale down. Increase scale!')

        uelems[key] = range(shift, shift + len(s_p))
        positions.extend(s_p) # Pick sites
        shift += len(s_p) # Update for next element

    new_poscar['unique']    = uelems
    new_poscar['positions'] = np.array(positions)
    return serializer.PoscarData(new_poscar)

def rotate_poscar(poscar_data,angle_deg,axis_vec):
    """Rotate a POSCAR about an axis passing through the origin.
    - **Parameters**
        - poscar_data: `poscar` data object.
        - angle_deg: Rotation angle in degrees.
        - axis_vec : (x,y,z) components of the axis about which rotation takes place.
    Only the basis is passed through the rotation; positions are left unchanged.
    """
    rotator = rotation(angle_deg=angle_deg, axis_vec=axis_vec)
    data = poscar_data.to_dict()
    rotated_basis = rotator.apply(data['basis'])
    data.update(basis=rotated_basis)
    data['extra_info']['comment'] = 'Modified by Pivotpy'
    return serializer.PoscarData(data)

def mirror_poscar(poscar_data, direction):
    """Mirror a POSCAR in a given direction. Sometimes you need it before joining two POSCARs.
    - **Parameters**
        - poscar_data: `poscar` data object.
        - direction: One of 'x', 'y', 'z'; the corresponding column of positions is mirrored.
    - **Raises**
        - ValueError: if `direction` is not one of 'x', 'y', 'z'.
    """
    # Explicit check: 'xyz'.index would also accept '' or 'xy'.
    if direction not in ('x', 'y', 'z'):
        raise ValueError(f"direction must be one of 'x', 'y', 'z', got {direction!r}")
    idx = 'xyz'.index(direction)

    poscar = poscar_data.to_dict()
    # np.array copies, so the positions array of the caller's object is never modified in place.
    positions = np.array(poscar['positions'], dtype=float)
    # Mirror by subtracting from 1, not by multiplying with -1: positions in [0,1] stay in [0,1].
    positions[:, idx] = 1 - positions[:, idx]
    poscar['positions'] = positions
    return serializer.PoscarData(poscar) # Return new POSCAR

def convert_poscar(poscar_data, atoms_mapping, basis_factor):
    """Convert a POSCAR to a similar structure of other atomic types or same type with strained basis.
    - **Parameters**
        - poscar_data: `poscar` data object.
        - atoms_mapping: Dictionary of {old_atom: new_atom}, e.g. {'Ga':'Al'} converts GaAs to AlAs structure.
          Types not present in the mapping keep their names.
        - basis_factor: Scaling factor multiplied with basis vectors. A single int/float value (useful for conversion
          to another type) or list/tuple/array of three values to scale along (a,b,c) vectors (useful for strained structures).
    - **Raises**
        - ValueError: if two atom types would end up with the same name, or `basis_factor` is a sequence
          that is not of three scalar values.
        - TypeError: if `basis_factor` is neither a number nor a list/tuple/array.
    """
    poscar_dict = poscar_data.to_dict() # Avoid modifying original
    unique = {}
    for old_name, sites in poscar_dict['unique'].items():
        new_name = atoms_mapping.get(old_name, old_name)
        if new_name in unique:
            # Overwriting the earlier entry would silently drop its sites.
            raise ValueError(f'atoms_mapping maps more than one atom type to {new_name!r}.')
        unique[new_name] = sites
    poscar_dict['unique'] = unique
    basis = np.array(poscar_dict['basis'], dtype=float) # Copy basis to avoid modifying original

    if isinstance(basis_factor, (int, float, np.integer, np.floating)):
        poscar_dict['basis'] = basis_factor*basis # Rescale basis
    elif isinstance(basis_factor, (list, tuple, np.ndarray)):
        if len(basis_factor) != 3:
            raise ValueError('basis_factor should be a list/tuple/array of length 3')

        if np.ndim(basis_factor) != 1:
            raise ValueError('basis_factor should be a list/tuple/array of 3 int/float values')

        # Row i of the basis (i-th vector) is multiplied by basis_factor[i].
        poscar_dict['basis'] = np.asarray(basis_factor, dtype=float)[:, None]*basis
    else:
        raise TypeError('basis_factor should be a list/tuple/array of 3 int/float values, got {}'.format(type(basis_factor)))

    return serializer.PoscarData(poscar_dict) # Return new POSCAR

def get_transform_matrix(poscar_data, target_basis):
    """Return the matrix `T` such that `T @ poscar_data.basis` gives `target_basis`.
    Useful in transforming a crystal structure with `transform_poscar`. The result is rounded to 16 decimals.
    """
    inverse_basis = np.linalg.inv(poscar_data.basis)
    transform = target_basis @ inverse_basis
    return np.round(transform, 16)

def transform_poscar(poscar_data, transform_matrix, repeat_given = (2,2,2),tol = 1e-2):
    """Transform a POSCAR with a given transformation matrix.
    Use `get_transform_matrix` to get transformation matrix from one basis to another.
    - **Parameters**
        - poscar_data: `poscar` data object.
        - transform_matrix: Matrix applied from the left; the new basis is `transform_matrix @ poscar_data.basis`.
        - repeat_given: Three integers >= 1. The POSCAR is repeated this many times along x, y, z before
          applying the transformation, to include all possible sites in the resultant cell.
        - tol: Sites whose coordinates in the new basis lie in (-tol, 1 - tol) are kept.
    - **Raises**
        - ValueError: if `repeat_given` is not three integers >= 1.
        - Exception: if no site of some atom type falls inside the new cell.
    """
    if len(repeat_given) != 3:
        raise ValueError('`repeat_given` must be a list of three integers.')

    # The old check used `and`, which let 0, negatives and fractions such as 0.5 through.
    for rep in repeat_given:
        if not isinstance(rep, (int, np.integer)) or rep < 1:
            raise ValueError('`repeat_given` must have all values as integer >= 1.')

    new_basis = np.matmul(transform_matrix,poscar_data.basis) # Transforms basis to target basis

    for n, _dir in zip(repeat_given,'xyz'): # Repeat if needed for including atoms in new cell
        poscar_data = repeat_poscar(poscar_data, int(n), direction=_dir) # Plain int for repeat_poscar.

    center_coords = poscar_data.coords - poscar_data.coords.mean(axis=0, keepdims=True) # Center around origin
    points = to_basis(new_basis,center_coords) # Transform coordinates to new basis around origin

    new_poscar = poscar_data.to_dict() # Update in it
    new_poscar['basis'] = new_basis
    new_poscar['extra_info']['scale'] = np.linalg.norm(new_basis[0])
    new_poscar['extra_info']['comment'] = 'Transformed by Pivotpy'

    unique_dict, positions, shift = {}, [], 0
    for key, value in poscar_data.unique.to_dict().items():
        s_p = points[value] # Get positions of key
        s_p = s_p[((s_p > -tol) & (s_p < 1 - tol)).all(axis=1)] # Get sites within tolerance

        if len(s_p) == 0:
            raise Exception(f'No sites found for {key!r}, transformation stopped! You may need to increase `repeat_given` parameter.')

        unique_dict[key] = range(shift,shift + len(s_p))
        positions.extend(s_p) # Pick sites
        shift += len(s_p) # Update for next element

    new_poscar['unique']  = unique_dict
    new_poscar['positions'] = np.array(positions)
    return serializer.PoscarData(new_poscar)

def add_vaccum(poscar_data, thickness, direction, left = False):
    """Add vacuum to a POSCAR.
    - **Parameters**
        - poscar_data: `poscar` data object.
        - thickness: Thickness of vacuum in Angstrom.
        - direction: Direction of vacuum. Can be 'x', 'y' or 'z'.
        - left: If True, vacuum is added to left of sites. By default, vacuum is added to right of sites.
    - **Raises**
        - ValueError: if `direction` is not one of 'x', 'y', 'z'.
    """
    # Exact check: a substring test would also accept '' or 'xy'.
    if direction not in ('x', 'y', 'z'):
        raise ValueError('Direction must be x, y or z.')

    poscar_dict = poscar_data.to_dict() # Avoid modifying original
    basis = np.array(poscar_dict['basis'], dtype=float) # Copy as float so in-place scaling works
    pos = np.array(poscar_dict['positions'], dtype=float) # Copy positions to avoid modifying original
    idx = 'xyz'.index(direction)
    norm = np.linalg.norm(basis[idx]) # Get length of basis vector
    total = norm + thickness
    s1, s2 = norm/total, thickness/total # Fractions of the new length: old cell, vacuum
    basis[idx,:] *= total/norm # Add thickness to basis
    if left:
        # Sites shift past the vacuum slab: x_new = (x*norm + thickness)/total = x*s1 + s2.
        # Previously s1 and s2 were swapped here, which distorted the structure.
        pos[:,idx] = pos[:,idx]*s1 + s2
    else:
        pos[:,idx] *= s1 # Scale down positions; vacuum ends up to the right
    poscar_dict['basis'] = basis
    poscar_dict['positions'] = pos

    return serializer.PoscarData(poscar_dict) # Return new POSCAR


def add_atoms(poscar_data, **name_pos_kwargs):
    """Add atoms to a POSCAR, given as `name=positions` keyword arguments in fractional coordinates,
    e.g. `add_atoms(poscar_data, Ga=[[0, 0, 0], [0.5, 0.5, 0.5]])`.

    Not implemented yet: always raises NotImplementedError. The previous body referenced an undefined
    `positions` variable, so it raised NameError instead.
    """
    raise NotImplementedError("Not implemented yet.")

def remove_atoms(poscar_data, name, positions):
    """Remove atoms of type `name` at `positions` from a POSCAR.

    Placeholder: not implemented yet, always raises NotImplementedError.
    """
    # TODO: a {name: positions} mapping may be a better interface for add/remove/replace.
    message = "Not implemented yet."
    raise NotImplementedError(message)

def replace_atoms(poscar_data, old_name, new_name, positions):
    """Replace atoms of type `old_name` at `positions` in a POSCAR with type `new_name`.

    Placeholder: not implemented yet, always raises NotImplementedError.
    """
    # TODO: a {name: positions} mapping may be a better interface for add/remove/replace.
    message = "Not implemented yet."
    raise NotImplementedError(message)
