from libsg.scene_types import BBox3D

import numpy as np
import pybullet as p
from collections import namedtuple
from itertools import groupby

# Handle to a simulated rigid body registered with Simulator:
#   id     - caller-supplied object id
#   bid    - pybullet body unique id
#   vid    - pybullet visual shape id (-1 when no separate visual shape was created)
#   cid    - pybullet collision shape id
#   static - True when the body was created with zero mass
Body = namedtuple('Body', 'id bid vid cid static')

# One body-body contact (or closest-point) record. idA / idB hold the caller's object ids,
# not pybullet body ids; the remaining fields are copied from the pybullet result tuple.
Contact = namedtuple(
    'Contact',
    'flags idA idB linkIndexA linkIndexB '
    'positionOnAInWS positionOnBInWS contactNormalOnBInWS distance normalForce',
)

# Result of casting a single ray. `id` is the hit object's id when something was hit;
# otherwise it keeps the raw (negative) pybullet value.
Intersection = namedtuple('Intersection', 'id linkIndex ray_fraction position normal')


class Simulator:
    """Wrapper around one pybullet physics client that tracks bodies by caller-supplied object ids.

    Bodies created through ``add_mesh`` / ``add_box`` are registered both under the caller's ``obj_id``
    and under their pybullet body id, so query results (contacts, closest points, ray hits) are
    reported in terms of ``obj_id`` rather than pybullet ids.
    """

    def __init__(self, mode='direct', verbose=False, use_y_up=False):
        """
        :param mode: ``'gui'`` for an interactive pybullet GUI client, ``'direct'`` for a headless one.
        :param verbose: print contact information from :meth:`run`.
        :param use_y_up: treat +y as the up axis (GUI axis and gravity); +z is up otherwise.
        :raises RuntimeError: if ``mode`` is unknown or pybullet cannot connect.
        """
        self._mode = mode
        self._verbose = verbose
        self._obj_id_to_body = {}  # caller obj_id -> Body
        self._bid_to_body = {}  # pybullet body id -> Body
        self._pid = None  # pybullet physics client id; None while disconnected
        self._use_y_up = use_y_up  # assume z up if false
        # when True, bodies added later are placed in the shared collision filter group automatically
        self._use_single_collection_filter_group = False
        self.connect()

    def connect(self):
        """(Re)connect to a fresh pybullet server of the configured mode and reset it.

        A previously connected client is disconnected first. Bodies registered with that client do not
        exist on the new server, so the body registries are cleared as well.

        :raises RuntimeError: if the mode is unknown or pybullet cannot connect.
        """
        # pybullet client ids start at 0, so compare against None rather than relying on truthiness
        if self._pid is not None:
            p.disconnect(physicsClientId=self._pid)
            self._pid = None
        self._obj_id_to_body = {}
        self._bid_to_body = {}

        # reconnect to appropriate server type
        if self._mode == 'gui':
            pid = p.connect(p.GUI)
        elif self._mode == 'direct':
            pid = p.connect(p.DIRECT)
        else:
            raise RuntimeError(f'Unknown simulator server mode={self._mode}')
        if pid < 0:  # pybullet reports a failed connection as -1
            raise RuntimeError(f'Could not connect to pybullet server in mode={self._mode}')
        self._pid = pid

        # reset and initialize gui if needed
        p.resetSimulation(physicsClientId=self._pid)
        if self._mode == 'gui':
            if self._use_y_up:
                p.configureDebugVisualizer(p.COV_ENABLE_Y_AXIS_UP, 1, physicsClientId=self._pid)
            for preview_flag in (p.COV_ENABLE_RGB_BUFFER_PREVIEW,
                                 p.COV_ENABLE_DEPTH_BUFFER_PREVIEW,
                                 p.COV_ENABLE_SEGMENTATION_MARK_PREVIEW):
                p.configureDebugVisualizer(preview_flag, 0, physicsClientId=self._pid)
            # disable rendering during loading -> faster; set_gui_rendering(True) turns it back on
            p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 0, physicsClientId=self._pid)

    def run(self):
        """Run the simulation.

        GUI mode:
            Sets gravity along the up axis and enables real-time simulation. Turns rendering on, then
            loops forever polling contacts; this call never returns.
        Direct mode:
            Performs a single simulation step and queries contacts once.

        With ``verbose`` set, contact information is printed.
        """
        if self._mode == 'gui':
            if self._use_y_up:
                p.setGravity(0, -10, 0, physicsClientId=self._pid)
            else:
                p.setGravity(0, 0, -10, physicsClientId=self._pid)
            p.setRealTimeSimulation(1, physicsClientId=self._pid)
            self.set_gui_rendering(enabled=True)
            while True:
                contacts = self.get_contacts(include_collision_with_static=True)
                if self._verbose:
                    print(f'#contacts={len(contacts)}, contact_pairs={contacts.keys()}')
        else:
            self.step()
            contacts = self.get_contacts(include_collision_with_static=True)
            if self._verbose:
                for (k, c) in contacts.items():
                    cp = self.get_closest_point(obj_id_a=c.idA, obj_id_b=c.idB)
                    print(f'contact pair={k} record={cp}')
                print(f'#contacts={len(contacts)}, contact_pairs={contacts.keys()}')

    def set_gui_rendering(self, enabled):
        """Aim the debug camera at the mean body position and enable/disable GUI rendering.

        :param enabled: whether rendering should be on.
        :returns: ``enabled`` in GUI mode; ``False`` (and no change) in any other mode.
        """
        if self._mode != 'gui':
            return False
        center = np.zeros(3)
        # with no bodies there is nothing to average; keep the camera on the origin instead of NaN
        if self._obj_id_to_body:
            positions = [self.get_state(obj_id)[0] for obj_id in self._obj_id_to_body]
            center = np.mean(positions, axis=0)
        p.resetDebugVisualizerCamera(cameraDistance=10.0,
                                     cameraYaw=45.0,
                                     cameraPitch=-30.0,
                                     cameraTargetPosition=center,
                                     physicsClientId=self._pid)
        p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 1 if enabled else 0, physicsClientId=self._pid)
        return enabled

    def _register_body(self, body):
        """Record ``body`` in both id lookups and apply the shared collision group when it is enabled."""
        self._obj_id_to_body[body.id] = body
        self._bid_to_body[body.bid] = body
        if self._use_single_collection_filter_group:
            self._set_single_collection_filter_group_for_body(body)
        return body

    def add_mesh(self, obj_id, obj_file, transform, vis_mesh_file=None, static=False, concave=False):
        """Create a mesh body and register it under ``obj_id``.

        :param obj_id: caller-side id used for this body in all query results.
        :param obj_file: mesh file for the collision shape.
        :param transform: object providing ``scale`` (mesh scale), ``translation`` (base position) and
            ``rotation`` (base orientation, passed to pybullet as a quaternion).
        :param vis_mesh_file: optional mesh file for a separate visual shape.
        :param static: create the body with zero mass (static in pybullet) instead of mass 1.
        :param concave: request a concave triangle-mesh collision shape (``GEOM_FORCE_CONCAVE_TRIMESH``).
        :returns: the registered :class:`Body`.
        """
        shape_kwargs = dict(fileName=obj_file, meshScale=transform.scale, physicsClientId=self._pid)
        if concave:
            shape_kwargs['flags'] = p.GEOM_FORCE_CONCAVE_TRIMESH
        cid = p.createCollisionShape(p.GEOM_MESH, **shape_kwargs)
        vid = -1  # -1: no separate visual shape
        if vis_mesh_file:
            vid = p.createVisualShape(p.GEOM_MESH, fileName=vis_mesh_file, meshScale=transform.scale,
                                      physicsClientId=self._pid)
        bid = p.createMultiBody(baseMass=0 if static else 1,
                                baseCollisionShapeIndex=cid,
                                baseVisualShapeIndex=vid,
                                basePosition=transform.translation,
                                baseOrientation=transform.rotation,
                                physicsClientId=self._pid)
        return self._register_body(Body(id=obj_id, bid=bid, vid=vid, cid=cid, static=static))

    def add_box(self, obj_id, half_extents, transform, static=False):
        """Create a box body and register it under ``obj_id``.

        :param obj_id: caller-side id used for this body in all query results.
        :param half_extents: box half sizes along each axis.
        :param transform: object providing ``translation`` and ``rotation`` (``scale`` is not used).
        :param static: create the body with zero mass (static in pybullet) instead of mass 1.
        :returns: the registered :class:`Body` (``vid`` is -1: no visual shape is created).
        """
        cid = p.createCollisionShape(p.GEOM_BOX, halfExtents=half_extents, physicsClientId=self._pid)
        bid = p.createMultiBody(baseMass=0 if static else 1,
                                baseCollisionShapeIndex=cid,
                                basePosition=transform.translation,
                                baseOrientation=transform.rotation,
                                physicsClientId=self._pid)
        return self._register_body(Body(id=obj_id, bid=bid, vid=-1, cid=cid, static=static))

    def remove(self, obj_id):
        """Remove the body registered under ``obj_id`` from the simulation and the registries.

        :raises KeyError: if ``obj_id`` is not registered.
        """
        body = self._obj_id_to_body[obj_id]
        p.removeBody(bodyUniqueId=body.bid, physicsClientId=self._pid)
        del self._obj_id_to_body[obj_id]
        del self._bid_to_body[body.bid]

    def set_state(self, obj_id, position, rot_q):
        """Teleport body ``obj_id`` to ``position`` with orientation quaternion ``rot_q``."""
        body = self._obj_id_to_body[obj_id]
        p.resetBasePositionAndOrientation(bodyUniqueId=body.bid, posObj=position, ornObj=rot_q,
                                          physicsClientId=self._pid)

    def get_state(self, obj_id):
        """Return ``(position, orientation_quaternion)`` of the base of body ``obj_id``."""
        body = self._obj_id_to_body[obj_id]
        pos, rot_q = p.getBasePositionAndOrientation(bodyUniqueId=body.bid, physicsClientId=self._pid)
        return pos, rot_q

    def step(self):
        """Advance the simulation by one pybullet step."""
        p.stepSimulation(physicsClientId=self._pid)

    def reset(self):
        """Remove all bodies from the simulation and clear the registries.

        The single-collision-group setting is kept and applies to bodies added afterwards.
        """
        p.resetSimulation(physicsClientId=self._pid)
        self._obj_id_to_body = {}
        self._bid_to_body = {}

    def get_aabb(self, obj_id):
        """Return the axis-aligned bounding box of body ``obj_id`` as a ``BBox3D``."""
        bid = self._obj_id_to_body[obj_id].bid
        aabb_min, aabb_max = p.getAABB(bodyUniqueId=bid, physicsClientId=self._pid)
        return BBox3D.from_min_max(aabb_min, aabb_max)

    def get_aabb_all(self):
        """Return a ``BBox3D`` enclosing the axis-aligned bounding boxes of all registered bodies."""
        points = []
        for body in self._obj_id_to_body.values():
            aabb_min, aabb_max = p.getAABB(bodyUniqueId=body.bid, physicsClientId=self._pid)
            points.append(aabb_min)
            points.append(aabb_max)
        return BBox3D.from_point_list(points)

    def _bid_or_any(self, obj_id):
        """Map an optional ``obj_id`` to its pybullet body id; ``None`` maps to -1 (any body)."""
        return -1 if obj_id is None else self._obj_id_to_body[obj_id].bid

    def get_closest_point(self, obj_id_a, obj_id_b, max_distance=np.inf):
        """Return the single closest-point record between two bodies.

        :returns: the :class:`Contact` with the smallest ``distance`` among pybullet's closest points
            within ``max_distance``, or ``None`` if there is none.
        """
        bid_a = self._obj_id_to_body[obj_id_a].bid
        bid_b = self._obj_id_to_body[obj_id_b].bid
        cps = p.getClosestPoints(bodyA=bid_a, bodyB=bid_b, distance=max_distance, physicsClientId=self._pid)
        closest_points = self._convert_contacts(cps)
        del cps  # NOTE force garbage collection of pybullet objects
        return min(closest_points, key=lambda x: x.distance, default=None)

    def get_closest_points(self, obj_id_a=None, obj_id_b=None, max_distance=np.inf):
        """Return all closest-point records within ``max_distance`` as a list of :class:`Contact`.

        An id left as ``None`` is passed to pybullet as body -1. Returns ``[]`` when nothing is found.
        """
        bid_a = self._bid_or_any(obj_id_a)
        bid_b = self._bid_or_any(obj_id_b)
        cps = p.getClosestPoints(bodyA=bid_a, bodyB=bid_b, distance=max_distance, physicsClientId=self._pid)
        closest_points = self._convert_contacts(cps)
        del cps  # NOTE force garbage collection of pybullet objects
        return closest_points

    def get_contacts(self, obj_id_a=None, obj_id_b=None, only_closest_contact_per_pair=True,
                     include_collision_with_static=True):
        """Return current contacts as a dict mapping ``(obj_id_a, obj_id_b)`` to a :class:`Contact`.

        :param obj_id_a: restrict to contacts involving this body (``None`` = any body).
        :param obj_id_b: restrict to contacts involving this body (``None`` = any body).
        :param only_closest_contact_per_pair: keep the contact with the smallest ``distance`` for each
            pair (the first one on ties). Otherwise the last reported contact per pair is kept.
        :param include_collision_with_static: when False, drop contacts in which either body is static.
        """
        bid_a = self._bid_or_any(obj_id_a)
        bid_b = self._bid_or_any(obj_id_b)
        cs = p.getContactPoints(bodyA=bid_a, bodyB=bid_b, physicsClientId=self._pid)
        contacts = self._convert_contacts(cs)
        del cs  # NOTE force garbage collection of pybullet objects

        if not include_collision_with_static:
            contacts = [c for c in contacts
                        if not self._obj_id_to_body[c.idA].static and not self._obj_id_to_body[c.idB].static]

        # key on the id tuple itself: a string key such as f'{a}_{b}' can collide when ids contain '_'
        contacts_dict = {}
        for c in contacts:
            key = (c.idA, c.idB)
            if not only_closest_contact_per_pair:
                contacts_dict[key] = c
            elif key not in contacts_dict or c.distance < contacts_dict[key].distance:
                contacts_dict[key] = c
        return contacts_dict

    def _set_single_collection_filter_group_for_body(self, body):
        """Put the base (link -1) of ``body`` in collision group 1 with collision mask 0xff."""
        p.setCollisionFilterGroupMask(body.bid, -1, 1, 0xff, physicsClientId=self._pid)

    def set_single_collection_filter_group(self):
        """Apply the shared collision filter group to all current bodies and to bodies added later."""
        for body in self._obj_id_to_body.values():
            self._set_single_collection_filter_group_for_body(body)
        self._use_single_collection_filter_group = True

    def _convert_contacts(self, contacts):
        """Convert raw pybullet contact / closest-point tuples into :class:`Contact` records.

        Fields 0-9 of each tuple are copied. Body ids at positions 1 and 2 are translated to the
        caller's object ids. Tuples involving a body not registered with this simulator are skipped.
        An empty or ``None`` input gives ``[]``.
        """
        out = []
        if not contacts:
            return out
        for c in contacts:
            bid_a = c[1]
            bid_b = c[2]
            if bid_a not in self._bid_to_body or bid_b not in self._bid_to_body:
                continue
            id_a = self._bid_to_body[bid_a].id
            id_b = self._bid_to_body[bid_b].id
            o = Contact(flags=c[0], idA=id_a, idB=id_b, linkIndexA=c[3], linkIndexB=c[4],
                        positionOnAInWS=c[5], positionOnBInWS=c[6], contactNormalOnBInWS=c[7],
                        distance=c[8], normalForce=c[9])
            out.append(o)
        return out

    def ray_test(self, from_pos, to_pos):
        """Cast a single ray from ``from_pos`` to ``to_pos`` and return an :class:`Intersection`.

        When something is hit, ``id`` is replaced by the hit body's object id. Otherwise it keeps
        pybullet's negative body id.
        """
        hits = p.rayTest(rayFromPosition=from_pos, rayToPosition=to_pos, physicsClientId=self._pid)
        intersection = Intersection._make(hits[0])  # one ray -> exactly one result tuple
        del hits  # NOTE force garbage collection of pybullet objects
        if intersection.id >= 0:  # if intersection, replace bid with id
            intersection = intersection._replace(id=self._bid_to_body[intersection.id].id)
        return intersection
