# Module metadata; the version tag is a date in YYYY_MM_DD form.
__author__ = 'ikibalin'
__version__ = "2019_11_26"

import os
import numpy
from fractions import Fraction
from pycifstar import Global

import cryspy.symcif.CONSTANTS_AND_FUNCTIONS as CONSTANTS_AND_FUNCTIONS
from typing import List, Tuple
from cryspy.common.cl_item_constr import ItemConstr
from cryspy.common.cl_loop_constr import LoopConstr


class SpaceGroupWyckoff(ItemConstr):
    """
Contains information about Wyckoff positions of a space group.
Only one site can be given for each special position but the
remainder can be generated by applying the symmetry operations
stored in _space_group_symop.operation_xyz.

Description in cif file::

 _space_group_Wyckoff.id    1
 _space_group_Wyckoff.multiplicity 192
 _space_group_Wyckoff.letter h
 _space_group_Wyckoff.site_symmetry 1
 _space_group_Wyckoff.coord_xyz x,y,z

:Attributes: - id
             - coord_xyz
             - letter
             - multiplicity
             - site_symmetry

:Internal attributes: - it_coord_xyz, centring_type (set with set_it_coord_xyz
                        and set_centring_type; read by form_object)
                      - r, b (matrix and translation parts parsed from coord_xyz)
                      - full_r, full_b, full_coord_xyz (every expression of
                        it_coord_xyz combined with every centring shift)

`Reference: <https://www.iucr.org/__data/iucr/cifdic_html/2/cif_sym.dic/Cspace_group_Wyckoff.html>`_
    """
    MANDATORY_ATTRIBUTE = ("coord_xyz",)
    OPTIONAL_ATTRIBUTE = ("id", "letter", "multiplicity", "site_symmetry")
    INTERNAL_ATTRIBUTE = ("full_coord_xyz", "r", "b", "full_r", "full_b", "it_coord_xyz", "centring_type")
    PREFIX = "space_group_Wyckoff"

    def __init__(self, id=None, coord_xyz=None, letter=None, multiplicity=None, site_symmetry=None):
        super(SpaceGroupWyckoff, self).__init__(mandatory_attribute=self.MANDATORY_ATTRIBUTE,
                                                optional_attribute=self.OPTIONAL_ATTRIBUTE,
                                                internal_attribute=self.INTERNAL_ATTRIBUTE,
                                                prefix=self.PREFIX)

        self.id = id
        self.coord_xyz = coord_xyz
        self.letter = letter
        self.multiplicity = multiplicity
        self.site_symmetry = site_symmetry
        if self.is_defined:
            # form_object is a property: reading it parses coord_xyz (side effect).
            self.form_object

    @property
    def id(self) -> str:
        """
An arbitrary identifier that is unique to a particular Wyckoff position.
        """
        return getattr(self, "__id")

    @id.setter
    def id(self, x):
        # None and the CIF placeholder "." are both stored as None.
        x_in = None if (x is None or x == ".") else str(x)
        setattr(self, "__id", x_in)

    @property
    def coord_xyz(self) -> str:
        """
Coordinates of one site of a Wyckoff position expressed in terms
of its fractional coordinates (x, y, z) in the unit cell. To generate
the coordinates of all sites of this Wyckoff position, it is necessary
to multiply these coordinates by the symmetry operations stored in
_space_group_symop.operation_xyz.

Where no value is given, the CIF dictionary assumes 'x,y,z'
(this class does not substitute that default: an unset value stays None).

:Example: 'x,1/2,0' (coordinates of Wyckoff site with 2.. symmetry)
        """
        return getattr(self, "__coord_xyz")

    @coord_xyz.setter
    def coord_xyz(self, x):
        x_in = None if (x is None or x == ".") else str(x)
        setattr(self, "__coord_xyz", x_in)

    @property
    def full_coord_xyz(self) -> List[str]:
        """
String form of every (full_r, full_b) pair built by form_object.
        """
        return getattr(self, "__full_coord_xyz")

    @property
    def letter(self) -> str:
        """
The Wyckoff letter associated with this position, as given in
International Tables for Crystallography Volume A. The enumeration
value \\a corresponds to the Greek letter ‘α’ used in International
Tables.

:Reference: International Tables for Crystallography (2002).
            Volume A, Space-group symmetry, edited by Th. Hahn, 5th ed.
            Dordrecht: Kluwer Academic Publishers.

The data value must be one of the following:

    a b c d e f g h i j k l m n o p q r s t u v w x
    y z
        """
        return getattr(self, "__letter")

    @letter.setter
    def letter(self, x):
        x_in = None if (x is None or x == ".") else str(x)
        setattr(self, "__letter", x_in)

    @property
    def multiplicity(self) -> int:
        """
The multiplicity of this Wyckoff position as given in International
Tables Volume A. It is the number of equivalent sites per conventional
unit cell.

:Reference: International Tables for Crystallography (2002).
            Volume A, Space-group symmetry, edited by Th. Hahn, 5th ed.
            Dordrecht: Kluwer Academic Publishers.
        """
        return getattr(self, "__multiplicity")

    @multiplicity.setter
    def multiplicity(self, x):
        x_in = None if (x is None or x == ".") else int(x)
        setattr(self, "__multiplicity", x_in)

    @property
    def site_symmetry(self) -> str:
        """
The subgroup of the space group that leaves the point fixed. It is
isomorphic to a subgroup of the point group of the space group.
The site-symmetry symbol indicates the symmetry in the symmetry
direction determined by the Hermann–Mauguin symbol of the
space group (see International Tables for Crystallography Volume
A, Section 2.2.12).

:Reference: International Tables for Crystallography (2002).
            Volume A, Space-group symmetry, edited by Th. Hahn, 5th ed.
            Dordrecht: Kluwer Academic Publishers.

:Examples: - ‘2.22’ (position 2b in space group No. 94, P4_2 2_1 2),
           - ‘42.2’ (position 6b in space group No. 222, Pn-3n),
           - ‘2..’ (site symmetry for the Wyckoff position 96f in space group No. 228, Fd-3c.
              The site-symmetry group is isomorphic to the point group 2 with the twofold axis
              along one of the <100> directions).
        """
        return getattr(self, "__site_symmetry")

    @site_symmetry.setter
    def site_symmetry(self, x):
        x_in = None if (x is None or x == ".") else str(x)
        setattr(self, "__site_symmetry", x_in)

    @property
    def sg_id(self):
        """
A child of _space_group.id allowing the Wyckoff position to be
identified with a particular space group.

Returns None when it has never been set (sg_id is not among the
attributes initialised by the base class).
        """
        return getattr(self, "__sg_id", None)

    @sg_id.setter
    def sg_id(self, x):
        x_in = None if (x is None or x == ".") else int(x)
        setattr(self, "__sg_id", x_in)

    @property
    def it_coord_xyz(self):
        """Coordinate expressions read by form_object to build full_r/full_b."""
        return getattr(self, "__it_coord_xyz")

    def set_it_coord_xyz(self, it_coord_xyz):
        """Store it_coord_xyz; call form_object afterwards to rebuild full_*."""
        setattr(self, "__it_coord_xyz", it_coord_xyz)

    @property
    def centring_type(self):
        """Centring type whose shifts form_object applies to it_coord_xyz."""
        return getattr(self, "__centring_type")

    def set_centring_type(self, centring_type):
        """Store centring_type; call form_object afterwards to rebuild full_*."""
        setattr(self, "__centring_type", centring_type)

    @property
    def r(self):
        """Matrix part parsed from coord_xyz by form_object."""
        return getattr(self, "__r")

    @property
    def b(self):
        """Translation part parsed from coord_xyz by form_object."""
        return getattr(self, "__b")

    @property
    def full_r(self):
        """Matrix parts of all expressions built by form_object."""
        return getattr(self, "__full_r")

    @property
    def full_b(self):
        """Translation parts (reduced modulo 1) of all expressions built by form_object."""
        return getattr(self, "__full_b")

    @property
    def form_object(self) -> bool:
        """
Parse coord_xyz into (r, b) and, when it_coord_xyz is set, build
full_r, full_b and full_coord_xyz by combining every expression of
it_coord_xyz with every shift returned for centring_type.

Return False when coord_xyz is undefined, True otherwise.
        """
        coord_xyz = self.coord_xyz
        if coord_xyz is None:
            return False
        labels = ("x", "y", "z")
        r, b = CONSTANTS_AND_FUNCTIONS.transform_string_to_r_b(coord_xyz, labels=labels)
        setattr(self, "__r", r)
        setattr(self, "__b", b)

        it_coord_xyz = self.it_coord_xyz
        if it_coord_xyz is None:
            # Nothing to expand; previously built full_* values are left untouched.
            return True

        shifts = CONSTANTS_AND_FUNCTIONS.get_shift_by_centring_type(self.centring_type)
        full_r, full_b, full_coord_xyz = [], [], []
        for _coord_xyz in it_coord_xyz:
            r_it, b_it = CONSTANTS_AND_FUNCTIONS.transform_string_to_r_b(_coord_xyz, labels=labels)
            for _shift in shifts:
                # Shifted translation, reduced modulo 1.
                b_new = numpy.mod(b_it + numpy.array(_shift, dtype=Fraction), 1)
                full_coord_xyz.append(
                    CONSTANTS_AND_FUNCTIONS.transform_r_b_to_string(r_it, b_new, labels=labels))
                full_r.append(r_it)
                full_b.append(b_new)

        setattr(self, "__full_r", full_r)
        setattr(self, "__full_b", full_b)
        setattr(self, "__full_coord_xyz", full_coord_xyz)
        return True

    def is_valid_for_fract(self, fract_x: float, fract_y: float, fract_z: float, tol=10 ** -5) -> bool:
        """
Check whether the point (fract_x, fract_y, fract_z) matches any of the
(full_r, full_b) expressions of this Wyckoff position.

The coordinates are approximated by fractions with denominators up to
int(1/tol) before being passed to CONSTANTS_AND_FUNCTIONS.is_good_for_mask.
Returns False when full_r has not been built (None) or is empty.
        """
        fract_x, fract_y, fract_z = float(fract_x), float(fract_y), float(fract_z)
        nval = int(tol ** -1)
        # The same rational approximation is tested against every expression: convert once.
        fr_x = Fraction(fract_x).limit_denominator(nval)
        fr_y = Fraction(fract_y).limit_denominator(nval)
        fr_z = Fraction(fract_z).limit_denominator(nval)
        full_r = self.full_r
        if full_r is None:
            return False
        return any(CONSTANTS_AND_FUNCTIONS.is_good_for_mask(r, b, fr_x, fr_y, fr_z)
                   for r, b in zip(full_r, self.full_b))

    def give_default_xyz(self, xyz_0):
        """
Map the point xyz_0 onto the site described by coord_xyz.

The first three entries of xyz_0 are approximated by fractions with
denominators up to 10**5. The method then tries the 6 permutations of
(x_0, y_0, z_0), each with 8 sign patterns. It keeps the first candidate h
for which some pair (full_r[k], full_b[k]) maps h onto (x_0, y_0, z_0)
modulo 1. If no candidate matches, or full_r is not built, it uses
h = (x_0, y_0, z_0).

:Returns: numpy float array (r @ h + b) % 1, where r and b come from coord_xyz.
        """
        r_float = self.r.astype(float)
        b_float = self.b.astype(float)

        nval = 10 ** 5
        x_0 = Fraction(xyz_0[0]).limit_denominator(nval)
        y_0 = Fraction(xyz_0[1]).limit_denominator(nval)
        z_0 = Fraction(xyz_0[2]).limit_denominator(nval)

        _h = None
        full_r = self.full_r
        if full_r is not None and len(full_r) > 0:
            np_r = numpy.array(full_r, dtype=Fraction)  # indexed [expression, row, column]
            np_b = numpy.array(self.full_b, dtype=Fraction)  # indexed [expression, component]
            target = numpy.array([x_0, y_0, z_0], dtype=Fraction)

            permutations = ((x_0, y_0, z_0), (x_0, z_0, y_0), (y_0, x_0, z_0),
                            (y_0, z_0, x_0), (z_0, x_0, y_0), (z_0, y_0, x_0))
            # The first match wins, so this order decides which of several
            # equivalent parameter sets is returned; it is the historical order.
            sign_patterns = ((1, 1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, -1),
                             (1, -1, -1), (-1, 1, -1), (-1, -1, 1), (-1, -1, -1))
            candidates = (numpy.array([s * v for s, v in zip(signs, perm)], dtype=Fraction)
                          for perm in permutations for signs in sign_patterns)
            for h in candidates:
                # Exact rational evaluation of full_r[k] @ h + full_b[k] - xyz_0 for every k.
                diff = ((np_r * h).sum(axis=2) + np_b - target).astype(float)
                # Distance to the nearest integer: unlike numpy.mod(diff, 1), this
                # also accepts residues lying just below an integer.
                residue = diff - numpy.round(diff)
                if numpy.any(numpy.all(numpy.isclose(residue, 0.), axis=1)):
                    _h = h.astype(float)
                    break

        if _h is None:
            # NOTE(review): fallback kept from the original ("may be needed when
            # x, y, z are refined"); confirm the intended semantics.
            _h = numpy.array([float(x_0), float(y_0), float(z_0)], dtype=float)
        return (numpy.matmul(r_float, _h) + b_float) % 1


class SpaceGroupWyckoffL(LoopConstr):
    """
Contains information about Wyckoff positions of a space group.
Only one site can be given for each special position but the
remainder can be generated by applying the symmetry operations
stored in _space_group_symop.operation_xyz.

Description in cif file::

 loop_
 _space_group_Wyckoff.id
 _space_group_Wyckoff.multiplicity
 _space_group_Wyckoff.letter
 _space_group_Wyckoff.site_symmetry
 _space_group_Wyckoff.coord_xyz
    1  192   h   1      x,y,z
    2   96   g   ..2    1/4,y,-y
    3   96   f   2..    x,1/8,1/8
    4   32   b   .32    1/4,1/4,1/4

:Attributes: - id
             - coord_xyz
             - letter
             - multiplicity
             - site_symmetry


:Mandatory attribute: - id (category key, 1st)
                      - coord_xyz

:Optional attribute: - letter
                     - multiplicity
                     - site_symmetry
                     - sg_id

:Methods: - get_id_for_fract(fract_x, fract_y, fract_z)
          - get_letter_for_fract(fract_x, fract_y, fract_z)
          - get_wyckoff_for_fract(fract_x, fract_y, fract_z)

`Reference: <https://www.iucr.org/__data/iucr/cifdic_html/2/cif_sym.dic/Cspace_group_Wyckoff.html>`_
    """
    CATEGORY_KEY = ("id",)
    ITEM_CLASS = SpaceGroupWyckoff

    def __init__(self, item=None, loop_name=""):
        super(SpaceGroupWyckoffL, self).__init__(category_key=self.CATEGORY_KEY, item_class=self.ITEM_CLASS,
                                                 loop_name=loop_name)
        # A fresh list per instance: a shared mutable default (item=[]) could
        # leak items between loops if the list were ever mutated in place.
        self.item = [] if item is None else item

    def get_id_for_fract(self, fract_x: float, fract_y: float, fract_z: float, tol=10 ** -5) -> str:
        """
Return the id of the Wyckoff position with the smallest multiplicity
whose is_valid_for_fract accepts the point (fract_x, fract_y, fract_z).

Among equal multiplicities the first position in the loop wins.
Positions with an unknown (None) multiplicity rank last.

:Raises: IndexError if no Wyckoff position of the loop accepts the point.
        """
        candidates = [(_item.id, _item.multiplicity) for _item in self.item
                      if _item.is_valid_for_fract(fract_x, fract_y, fract_z, tol)]
        if not candidates:
            raise IndexError("No Wyckoff position matches the point ({}, {}, {}).".format(
                fract_x, fract_y, fract_z))
        # min() keeps the first of equal minima, like the former stable sort;
        # the key avoids comparing None with int.
        best = min(candidates, key=lambda pair: (pair[1] is None, 0 if pair[1] is None else pair[1]))
        return best[0]

    def get_letter_for_fract(self, fract_x: float, fract_y: float, fract_z: float, tol=10 ** -5) -> str:
        """
Return the Wyckoff letter of the position chosen by get_id_for_fract
(tol is forwarded).
        """
        _id = self.get_id_for_fract(fract_x, fract_y, fract_z, tol)
        return self[_id].letter

    def get_wyckoff_for_fract(self, fract_x: float, fract_y: float, fract_z: float,
                              tol=10 ** -5) -> SpaceGroupWyckoff:
        """
Return the Wyckoff item chosen by get_id_for_fract (tol is forwarded).
        """
        _id = self.get_id_for_fract(fract_x, fract_y, fract_z, tol)
        return self[_id]
