from typing import NoReturn
import numpy
from fractions import Fraction
from cryspy.A_functions_base.function_1_strings import \
    transform_string_to_r_b, transform_r_b_to_string
from cryspy.A_functions_base.function_2_space_group import \
    get_shift_by_centring_type, is_good_for_mask
from cryspy.B_parent_classes.cl_1_item import ItemN
from cryspy.B_parent_classes.cl_2_loop import LoopN


class SpaceGroupWyckoff(ItemN):
    """Wyckoff position.

    Contains information about Wyckoff positions of a space group.
    Only one site can be given for each special position but the
    remainder can be generated by applying the symmetry operations
    stored in _space_group_symop.operation_xyz.

    Internal attributes filled by ``form_object``:

    * ``r``, ``b`` -- matrix and translation parts parsed from
      ``coord_xyz`` (applied as ``r @ xyz + b`` in ``give_default_xyz``);
    * ``full_r``, ``full_b``, ``full_coord_xyz`` -- every operation of the
      protected ``it_coord_xyz`` list combined with every shift returned by
      ``get_shift_by_centring_type(centring_type)``, translations reduced
      modulo 1.
    """
    ATTR_MANDATORY_NAMES = ("coord_xyz", )
    ATTR_MANDATORY_TYPES = (str, )
    ATTR_MANDATORY_CIF = ("coord_xyz", )

    ATTR_OPTIONAL_NAMES = ("id", "letter", "multiplicity", "site_symmetry")
    ATTR_OPTIONAL_TYPES = (str, str, int, str)
    ATTR_OPTIONAL_CIF = ("id", "letter", "multiplicity", "site_symmetry")

    ATTR_NAMES = ATTR_MANDATORY_NAMES + ATTR_OPTIONAL_NAMES
    ATTR_TYPES = ATTR_MANDATORY_TYPES + ATTR_OPTIONAL_TYPES
    ATTR_CIF = ATTR_MANDATORY_CIF + ATTR_OPTIONAL_CIF

    ATTR_INT_NAMES = ("full_coord_xyz", "r", "b", "full_r", "full_b",
                      "centring_type")
    ATTR_INT_PROTECTED_NAMES = ("it_coord_xyz", )

    # parameters considered are refined parameters
    ATTR_REF = ()
    ATTR_SIGMA = tuple([f"{_h:}_sigma" for _h in ATTR_REF])
    ATTR_CONSTR_FLAG = tuple([f"{_h:}_constraint" for _h in ATTR_REF])
    ATTR_REF_FLAG = tuple([f"{_h:}_refinement" for _h in ATTR_REF])

    # constraints on the parameters
    D_CONSTRAINTS = {}

    # default values for the parameters: sigmas default to 0., constraint
    # and refinement flags to False.  dict.fromkeys avoids leaking a loop
    # variable into the class namespace.
    D_DEFAULT = dict.fromkeys(ATTR_SIGMA, 0.)
    D_DEFAULT.update(dict.fromkeys(ATTR_CONSTR_FLAG + ATTR_REF_FLAG, False))

    PREFIX = "space_group_Wyckoff"

    def __init__(self, **kwargs) -> NoReturn:
        """Create the item, apply ``D_DEFAULT`` and then any ``kwargs``."""
        super(SpaceGroupWyckoff, self).__init__()

        # defined for any integer and float parameters
        D_MIN = {}

        # defined for any integer and float parameters
        D_MAX = {}

        self.__dict__["D_MIN"] = D_MIN
        self.__dict__["D_MAX"] = D_MAX
        for key, attr in self.D_DEFAULT.items():
            setattr(self, key, attr)
        for key, attr in kwargs.items():
            setattr(self, key, attr)

    def form_object(self) -> NoReturn:
        """Derive the internal symmetry data from ``coord_xyz``.

        Parses ``coord_xyz`` into ``r``/``b``.  When ``it_coord_xyz`` is
        set, also builds ``full_r``, ``full_b`` and ``full_coord_xyz`` by
        combining each of its operations with each centring shift; the
        shifted translations are reduced modulo 1.

        Returns False (and sets nothing) when ``coord_xyz`` is None,
        otherwise None.
        """
        coord_xyz = self.coord_xyz
        if coord_xyz is None:
            return False
        labels = ("x", "y", "z")
        r, b = transform_string_to_r_b(coord_xyz, labels=labels)
        self.__dict__["r"] = r
        self.__dict__["b"] = b

        it_coord_xyz = self.it_coord_xyz
        shift = get_shift_by_centring_type(self.centring_type)

        if it_coord_xyz is not None:
            full_r, full_b, full_coord_xyz = [], [], []
            for _coord_xyz in it_coord_xyz:
                r_op, b_op = transform_string_to_r_b(_coord_xyz, labels=labels)
                for _shift in shift:
                    b_new = numpy.mod(
                        b_op + numpy.array(_shift, dtype=Fraction), 1)
                    full_coord_xyz.append(
                        transform_r_b_to_string(r_op, b_new, labels=labels))
                    full_r.append(r_op)
                    full_b.append(b_new)

            self.__dict__["full_r"] = full_r
            self.__dict__["full_b"] = full_b
            self.__dict__["full_coord_xyz"] = full_coord_xyz

    def is_valid_for_fract(self, fract_x: float, fract_y: float,
                           fract_z: float, tol=10 ** -5) -> bool:
        """Check whether a fractional point belongs to this Wyckoff position.

        The coordinates are converted to fractions whose denominators are
        bounded by ``round(1 / tol)`` (at least 1), then tested against
        every (``full_r``, ``full_b``) pair with ``is_good_for_mask``.

        Returns True as soon as one pair matches, False otherwise
        (including when there are no stored operations).
        """
        # round() instead of int(): 1 / 1e-5 is 99999.99999999999 in binary
        # floating point, which int() would truncate to 99999.  max(1, ...)
        # keeps limit_denominator valid for tol > 1.
        nval = max(1, round(1 / tol))
        # Loop-invariant: convert the coordinates once.
        x_fr = Fraction(float(fract_x)).limit_denominator(nval)
        y_fr = Fraction(float(fract_y)).limit_denominator(nval)
        z_fr = Fraction(float(fract_z)).limit_denominator(nval)
        for r, b in zip(self.full_r, self.full_b):
            if is_good_for_mask(r, b, x_fr, y_fr, z_fr):
                return True
        return False

    def give_default_xyz(self, xyz_0):
        """Map a point onto the representative ``coord_xyz`` of this position.

        The components of ``xyz_0`` are converted to fractions (denominator
        at most 10**5).  Every permutation of the three components is
        combined with every sign pattern, in a fixed order.  The first
        candidate ``h`` for which some stored operation (``full_r``,
        ``full_b``) maps ``h`` onto ``xyz_0`` modulo 1 is selected.  If no
        candidate matches, ``xyz_0`` itself is used.

        Returns ``(r @ h + b) % 1`` as a float numpy array, with ``r`` and
        ``b`` parsed from ``coord_xyz``.
        """
        # Sign patterns in the order they are tried: no flip, single flips,
        # double flips, triple flip.
        sign_patterns = ((1, 1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, -1),
                         (1, -1, -1), (-1, 1, -1), (-1, -1, 1), (-1, -1, -1))

        b_float = self.b.astype(float)
        r_float = self.r.astype(float)
        np_r = numpy.array(self.full_r, dtype=Fraction)
        np_b = numpy.array(self.full_b, dtype=Fraction)

        nval = 10**5
        x_0 = Fraction(xyz_0[0]).limit_denominator(nval)
        y_0 = Fraction(xyz_0[1]).limit_denominator(nval)
        z_0 = Fraction(xyz_0[2]).limit_denominator(nval)
        target = (x_0, y_0, z_0)
        l_ind = [(x_0, y_0, z_0), (x_0, z_0, y_0), (y_0, x_0, z_0),
                 (y_0, z_0, x_0), (z_0, x_0, y_0), (z_0, y_0, x_0)]

        _h = None
        for x_fr, y_fr, z_fr in l_ind:
            for s_x, s_y, s_z in sign_patterns:
                cand = (s_x * x_fr, s_y * y_fr, s_z * z_fr)
                # An operation matches when all three rows of
                # r @ cand + b - target are integers.  The arithmetic is
                # exact (Fraction); the float conversion only feeds isclose.
                flag = numpy.ones(np_r.shape[0], dtype=bool)
                for i_row in range(3):
                    val = (np_r[:, i_row, 0] * cand[0]
                           + np_r[:, i_row, 1] * cand[1]
                           + np_r[:, i_row, 2] * cand[2]
                           + np_b[:, i_row] - target[i_row]).astype(float)
                    flag &= numpy.isclose(numpy.mod(val, 1), 0.)
                if numpy.any(flag):
                    _h = numpy.array([s_x * float(x_fr), s_y * float(y_fr),
                                      s_z * float(z_fr)], dtype=float)
                    break
            if _h is not None:
                break

        if _h is None:
            # Fallback kept from the original author, who was unsure about it
            # but noted it may be needed when x, y, z are refined.
            _h = numpy.array([float(x_0), float(y_0), float(z_0)], dtype=float)
        xyz_new = (numpy.matmul(r_float, _h) + b_float) % 1
        return xyz_new

class SpaceGroupWyckoffL(LoopN):
    """Wyckoff positions.

    Contains information about Wyckoff positions of a space group.
    Only one site can be given for each special position but the
    remainder can be generated by applying the symmetry operations
    stored in _space_group_symop.operation_xyz.

    """
    ITEM_CLASS = SpaceGroupWyckoff
    ATTR_INDEX = "id"

    def __init__(self, loop_name=None) -> NoReturn:
        """Create an empty loop, optionally named ``loop_name``."""
        super(SpaceGroupWyckoffL, self).__init__()
        self.__dict__["items"] = []
        self.__dict__["loop_name"] = loop_name

    def get_id_for_fract(self, fract_x: float, fract_y: float,
                         fract_z: float, tol=10 ** -5) -> str:
        """Return the ``id`` of the matching position with lowest multiplicity.

        Every item is tested with ``is_valid_for_fract`` using ``tol``.
        Among equal multiplicities, the item that appears first wins.

        Raises:
            IndexError: if no Wyckoff position matches the point.
        """
        l_res = [(_item.id, _item.multiplicity) for _item in self.items
                 if _item.is_valid_for_fract(fract_x, fract_y, fract_z, tol)]
        if not l_res:
            # Same exception type as the former ``sorted(...)[0]`` failure,
            # but with an explanation.
            raise IndexError(
                f"No Wyckoff position matches the point "
                f"({fract_x}, {fract_y}, {fract_z}) with tol={tol}.")
        # min() picks the first minimal element, like a stable sort would.
        return min(l_res, key=lambda x: x[1])[0]

    def get_letter_for_fract(self, fract_x: float, fract_y: float,
                             fract_z: float, tol=10 ** -5) -> str:
        """Return the Wyckoff ``letter`` of the position matching the point.

        ``tol`` is forwarded to ``get_id_for_fract``; previously it was
        accepted but ignored.
        """
        _id = self.get_id_for_fract(fract_x, fract_y, fract_z, tol)
        return self[_id].letter

    def get_wyckoff_for_fract(self, fract_x: float, fract_y: float,
                              fract_z: float,
                              tol=10 ** -5) -> SpaceGroupWyckoff:
        """Return the item (looked up via ``self[id]``) matching the point.

        ``tol`` is forwarded to ``get_id_for_fract``; previously it was
        accepted but ignored.
        """
        _id = self.get_id_for_fract(fract_x, fract_y, fract_z, tol)
        return self[_id]
    
# s_cont = """
#   loop_
#   _space_group_wyckoff.id
#   _space_group_wyckoff.multiplicity
#   _space_group_wyckoff.letter
#   _space_group_wyckoff.site_symmetry
#   _space_group_wyckoff.coord_xyz
#     1  192   h   1      x,y,z
#     2   96   g   ..2    1/4,y,-y
#     3   96   f   2..    x,1/8,1/8
#     4   32   b   .32    1/4,1/4,1/4
# """

# obj = SpaceGroupWyckoffL.from_cif(s_cont)
# print(obj, end="\n\n")
# print(obj["1"], end="\n\n")
