from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Union

import numpy as np

from ada.base.physical_objects import BackendGeom
from ada.concepts.bounding_box import BoundingBox
from ada.concepts.curves import CurvePoly, CurveRevolve
from ada.concepts.points import Node
from ada.concepts.transforms import Placement
from ada.config import Settings
from ada.core.utils import Counter, roundoff
from ada.core.vector_utils import (
    angle_between,
    calc_yvec,
    calc_zvec,
    unit_vector,
    vector_length,
)
from ada.materials import Material
from ada.materials.utils import get_material
from ada.sections import Section
from ada.sections.utils import get_section

if TYPE_CHECKING:
    from OCC.Core.TopoDS import TopoDS_Shape

    from ada.concepts.connections import JointBase
    from ada.fem.elements import HingeProp

# Module-level counters, both starting at 1. They are not referenced in this part
# of the file; presumably used elsewhere for numbering (TODO confirm).
section_counter, material_counter = Counter(1), Counter(1)


class Justification:
    """String constants naming the supported beam centreline justifications."""

    NA: str = "neutral axis"
    TOS: str = "top of steel"


class Beam(BackendGeom):
    """
    The base Beam object

    :param name: Name of beam
    :param n1: Start position of beam. List or Node object. Replaced by the curve start when ``curve`` is given.
    :param n2: End position of beam. List or Node object. Replaced by the curve end when ``curve`` is given.
    :param sec: Section definition. Str or Section Object
    :param mat: Material. Str or Material object. String: ['S355' & 'S420'] (default is 'S355' if None is passed)
    :param tap: Tapering of beam. Str or Section object. Overrides any taper derived from ``sec``.
    :param jusl: Justification of Beam centreline (see :class:`Justification`)
    :param up: Local up vector of length 3; must not be (anti-)parallel to the beam axis. If None it is
        derived from the beam axis via ``calc_zvec`` and optionally rotated about the axis by ``angle``.
    :param angle: Rotation in degrees about the beam axis, applied to the derived up vector when ``up``
        is None. When ``up`` is given, the stored angle is recomputed from ``up``.
    :param curve: CurvePoly or CurveRevolve. When given, its end points define n1 and n2.
    :param e1: Optional offset added to the start point when computing ``length``
    :param e2: Optional offset added to the end point when computing ``length``
    :param colour: Colour of the beam (stored as-is)
    :param parent: Parent object
    :param metadata: Metadata forwarded to the base class
    :param ifc_geom: IFC geometry (stored as-is)
    :param opacity: Opacity (stored as-is)
    :param units: Length units, forwarded to the base class and to newly created Nodes
    :param ifc_elem: IFC element forwarded to the base class
    :param guid: Global unique id forwarded to the base class
    :param placement: Placement forwarded to the base class. A new ``Placement()`` is created when None.
    :raises ValueError: If the curve type is unsupported, the up vector is not of length 3, or the
        up vector is (nearly) parallel to the beam axis.
    """

    JUSL_TYPES = Justification

    def __init__(
        self,
        name,
        n1=None,
        n2=None,
        sec: Union[str, Section] = None,
        mat: Union[str, Material] = None,
        tap: Union[str, Section] = None,
        jusl=JUSL_TYPES.NA,
        up=None,
        angle=0.0,
        curve: Union[CurvePoly, CurveRevolve] = None,
        e1=None,
        e2=None,
        colour=None,
        parent=None,
        metadata=None,
        ifc_geom=None,
        opacity=None,
        units="m",
        ifc_elem=None,
        guid=None,
        placement: Placement = None,
    ):
        # A ``Placement()`` default in the signature is evaluated once and would be shared by
        # every Beam created without an explicit placement.
        if placement is None:
            placement = Placement()
        super().__init__(name, metadata=metadata, units=units, guid=guid, ifc_elem=ifc_elem, placement=placement)
        if curve is not None:
            curve.parent = self
            if type(curve) is CurvePoly:
                n1 = curve.points3d[0]
                n2 = curve.points3d[-1]
            elif type(curve) is CurveRevolve:
                n1 = curve.p1
                n2 = curve.p2
            else:
                raise ValueError(f'Unsupported curve type "{type(curve)}"')

        self.colour = colour
        self._curve = curve
        self._n1 = n1 if isinstance(n1, Node) else Node(n1[:3], units=units)
        self._n2 = n2 if isinstance(n2, Node) else Node(n2[:3], units=units)
        self._jusl = jusl

        self._connected_to = []
        self._connected_end1 = None
        self._connected_end2 = None
        self._tos = None
        self._e1 = e1
        self._e2 = e2
        self._hinge_prop = None

        self._parent = parent
        self._bbox = None

        # Section and Material setup
        self._section, self._taper = get_section(sec)
        self._section.refs.append(self)
        self._taper.refs.append(self)
        self._material = get_material(mat)
        self._material.refs.append(self)

        if tap is not None:
            self._taper, _ = get_section(tap)
            # Register on the explicit taper as well, like the section and material above.
            self._taper.refs.append(self)

        self._section.parent = self
        self._taper.parent = self

        # Define orientations
        xvec = unit_vector(self.n2.p - self.n1.p)
        tol = 1e-3
        gup = np.array(calc_zvec(xvec))

        if up is None:
            if angle is not None and angle != 0.0:
                from pyquaternion import Quaternion

                rot_mat = Quaternion(axis=xvec, degrees=angle).rotation_matrix
                gup = np.matmul(gup, np.transpose(rot_mat))
            up = np.array([roundoff(x) if abs(x) != 0.0 else 0.0 for x in gup])
            yvec = calc_yvec(xvec, up)
        else:
            if len(up) != 3:
                raise ValueError("Up vector must be length 3")
            up = np.array(up)
            # |xvec x up| = |up| * sin(theta): rejects up vectors parallel OR anti-parallel to the
            # beam axis independent of their length (a plain |xvec - up| test misses both cases).
            if np.linalg.norm(np.cross(xvec, up)) <= tol * np.linalg.norm(up):
                raise ValueError("The assigned up vector is too close to your beam direction")
            yvec = calc_yvec(xvec, up)
            # TODO: Fix improper calculation of angle (e.g. xvec = [1,0,0] and up = [0, 1,0] should be 270?
            angle = np.rad2deg(angle_between(up, yvec))

        self._xvec = xvec
        self._yvec = np.array([roundoff(x) for x in yvec])
        self._up = up
        self._angle = angle

        self._ifc_geom = ifc_geom
        self._opacity = opacity

    def get_outer_points(self):
        """Return the section's outer profile points placed at both beam ends.

        The 2D outer curve of the section profile (all disconnected outer curves chained together
        when the profile is disconnected) is passed to ``local_2_global_nodes`` at n1 and at n2,
        using the beam's yvec and xvec.

        :return: Tuple ``(nodes_at_n1, nodes_at_n2)``
        """
        from itertools import chain

        from ada.core.vector_utils import local_2_global_nodes

        section_profile = self.section.get_section_profile(False)
        if section_profile.disconnected:
            ot = list(chain.from_iterable([x.points2d for x in section_profile.outer_curve_disconnected]))
        else:
            ot = section_profile.outer_curve.points2d

        yv = self.yvec
        xv = self.xvec
        p1 = self.n1.p
        p2 = self.n2.p

        nodes_p1 = local_2_global_nodes(ot, p1, yv, xv)
        nodes_p2 = local_2_global_nodes(ot, p2, yv, xv)

        return nodes_p1, nodes_p2

    def _generate_ifc_elem(self):
        """Create the IFC entity for this beam via ``write_ifc_beam``."""
        from ada.ifc.write.write_beams import write_ifc_beam

        return write_ifc_beam(self)

    def calc_con_points(self, point_tol=Settings.point_tol):
        """Find connection points lying between the beam ends.

        Joints whose centre coincides (within ``point_tol``) with the start or end point are stored
        in ``connected_end1`` / ``connected_end2`` instead of being returned. Points farther than the
        beam length from either end are skipped, as are points within ``point_tol`` of the previous one.

        :param point_tol: Distance tolerance for treating points as coincident
        :return: List of intermediate points (numpy arrays), in the order produced by
            ``sort_points_by_dist`` relative to the start point
        """
        from ada.core.vector_utils import sort_points_by_dist

        a = self.n1.p
        b = self.n2.p
        joints = self.connected_to
        points = [tuple(con.centre) for con in joints]

        def is_mem_eccentric(mem, centre):
            # An end is eccentric when it is off the joint centre but within 90% of the member length.
            is_ecc = False
            end = None
            if point_tol < vector_length(mem.n1.p - centre) < mem.length * 0.9:
                is_ecc = True
                end = mem.n1.p
            if point_tol < vector_length(mem.n2.p - centre) < mem.length * 0.9:
                is_ecc = True
                end = mem.n2.p
            return is_ecc, end

        def joint_at(p):
            # Points appended below for eccentric members have no joint of their own; indexing
            # ``joints`` with their position would raise IndexError.
            idx = points.index(tuple(p))
            return joints[idx] if idx < len(joints) else None

        if len(joints) == 1:
            con = joints[0]
            if con.main_mem == self:
                for m in con.beams:
                    if m != self:
                        is_ecc, end = is_mem_eccentric(m, con.centre)
                        if is_ecc:
                            logging.info(f'do something with end "{end}"')
                            points.append(tuple(end))

        bmlen = self.length  # loop invariant
        midpoints = []
        prev_p = None
        for p in sort_points_by_dist(a, points):
            p = np.array(p)
            vlena = vector_length(p - a)
            vlenb = vector_length(p - b)

            if prev_p is not None and vector_length(p - prev_p) < point_tol:
                continue

            if vlena < point_tol:
                joint = joint_at(p)
                if joint is not None:
                    self._connected_end1 = joint
                prev_p = p
                continue

            if vlenb < point_tol:
                joint = joint_at(p)
                if joint is not None:
                    self._connected_end2 = joint
                prev_p = p
                continue

            if vlena > bmlen or vlenb > bmlen:
                prev_p = p
                continue

            midpoints.append(p)
            prev_p = p

        return midpoints

    @property
    def units(self):
        """Length units of the beam."""
        return self._units

    @units.setter
    def units(self, value):
        # Propagate to nodes, section, material and penetrations only when the units change.
        if self._units != value:
            self.n1.units = value
            self.n2.units = value
            self.section.units = value
            self.material.units = value
            for pen in self.penetrations:
                pen.units = value
            self._units = value

    @property
    def section(self) -> Section:
        return self._section

    @section.setter
    def section(self, value):
        self._section = value

    @property
    def taper(self) -> Section:
        return self._taper

    @taper.setter
    def taper(self, value):
        self._taper = value

    @property
    def material(self) -> Material:
        return self._material

    @material.setter
    def material(self, value):
        self._material = value

    @property
    def member_type(self):
        """Classify the beam by its axis direction.

        :return: "Column" if the axis is parallel to global Z (``is_parallel`` with tol=1e-1),
            "Girder" if the axis has an exactly zero Z component, otherwise "Brace".
        """
        from ada.core.vector_utils import is_parallel

        xvec = self.xvec
        if is_parallel(xvec, [0.0, 0.0, 1.0], tol=1e-1):
            mtype = "Column"
        elif xvec[2] == 0.0:
            mtype = "Girder"
        else:
            mtype = "Brace"

        return mtype

    @property
    def connected_to(self) -> List["JointBase"]:
        return self._connected_to

    @property
    def connected_end1(self):
        """Joint found at n1 by ``calc_con_points`` (None until found)."""
        return self._connected_end1

    @property
    def connected_end2(self):
        """Joint found at n2 by ``calc_con_points`` (None until found)."""
        return self._connected_end2

    @property
    def length(self) -> float:
        """Returns the length of the beam, including the end offsets ``e1``/``e2`` when set."""
        p1 = self.n1.p
        p2 = self.n2.p

        # Out-of-place additions: ``+=`` would modify, in place, the arrays returned by ``Node.p``.
        if self.e1 is not None:
            p1 = p1 + self.e1
        if self.e2 is not None:
            p2 = p2 + self.e2
        return vector_length(p2 - p1)

    @property
    def jusl(self):
        """Justification line"""
        return self._jusl

    @property
    def ori(self):
        """Return the tuple (xvec, yvec, up) of the beam."""
        return self.xvec, self.yvec, self.up

    @property
    def xvec(self) -> np.ndarray:
        """Local X-vector"""
        return self._xvec

    @property
    def yvec(self) -> np.ndarray:
        """Local Y-vector"""
        return self._yvec

    @property
    def up(self) -> np.ndarray:
        """Up vector given or derived in ``__init__``."""
        return self._up

    @property
    def n1(self) -> Node:
        return self._n1

    @n1.setter
    def n1(self, value):
        self._n1 = value

    @property
    def n2(self) -> Node:
        return self._n2

    @n2.setter
    def n2(self, value):
        self._n2 = value

    @property
    def bbox(self) -> BoundingBox:
        """Bounding Box of beam (created lazily and cached)."""
        if self._bbox is None:
            self._bbox = BoundingBox(self)

        return self._bbox

    @property
    def e1(self) -> np.ndarray:
        return self._e1

    @e1.setter
    def e1(self, value):
        # Keep None as None; ``np.array(None)`` would pass ``is not None`` checks and break ``length``.
        self._e1 = None if value is None else np.array(value)

    @property
    def e2(self) -> np.ndarray:
        return self._e2

    @e2.setter
    def e2(self, value):
        # Keep None as None; ``np.array(None)`` would pass ``is not None`` checks and break ``length``.
        self._e2 = None if value is None else np.array(value)

    @property
    def hinge_prop(self) -> "HingeProp":
        return self._hinge_prop

    @hinge_prop.setter
    def hinge_prop(self, value: "HingeProp"):
        # Link the hinge back to this beam and bind each defined hinge end to the matching node.
        value.beam_ref = self
        if value.end1 is not None:
            value.end1.concept_node = self.n1
        if value.end2 is not None:
            value.end2.concept_node = self.n2
        self._hinge_prop = value

    @property
    def opacity(self):
        return self._opacity

    @property
    def curve(self) -> Union[CurvePoly, CurveRevolve]:
        return self._curve

    @property
    def line(self):
        """Wire between the two end nodes (intermediate connection points are not included)."""
        from ada.occ.utils import make_wire_from_points

        return make_wire_from_points([self.n1.p, self.n2.p])

    @property
    def shell(self) -> "TopoDS_Shape":
        """Shell geometry of the beam with penetrations applied."""
        from ada.occ.utils import apply_penetrations, create_beam_geom

        geom = apply_penetrations(create_beam_geom(self, False), self.penetrations)

        return geom

    @property
    def solid(self) -> "TopoDS_Shape":
        """Solid geometry of the beam with penetrations applied."""
        from ada.occ.utils import apply_penetrations, create_beam_geom

        geom = apply_penetrations(create_beam_geom(self, True), self.penetrations)

        return geom

    def __hash__(self):
        return hash(self.guid)

    def __eq__(self, other: Beam):
        """Compare beams attribute by attribute, skipping parent references and IFC handles."""
        if not isinstance(other, Beam):
            return NotImplemented
        for key, val in self.__dict__.items():
            if "parent" in key or key in ("_ifc_settings", "_ifc_elem"):
                continue
            if key not in other.__dict__:
                return False
            oval = other.__dict__[key]

            if isinstance(val, (list, tuple, np.ndarray)):
                # zip() silently truncates, so lengths must be compared explicitly.
                if not isinstance(oval, (list, tuple, np.ndarray)) or len(oval) != len(val):
                    return False
                if False in [x == y for x, y in zip(oval, val)]:
                    return False
            try:
                res = oval != val
            except ValueError as e:
                logging.error(e)
                # An attribute that cannot be compared means equality cannot be established.
                return False

            # numpy scalars yield np.bool_, which is never ``True`` by identity; element-wise
            # array results (ndarray) were already handled above.
            if isinstance(res, (bool, np.bool_)) and res:
                return False

        return True

    def __repr__(self):
        p1s = self.n1.p.tolist()
        p2s = self.n2.p.tolist()
        secn = self.section.sec_str
        matn = self.material.name
        return f'Beam("{self.name}", {p1s}, {p2s}, {secn}, {matn})'