import numpy as np
import igraph as ig
import sys
import copy
import time

# Tree cluster class
#
# Constructor parameters
#   @param g an igraph.Graph object.  The graph must be a directed tree.
#   @param m_list a list of numpy matrices.  Each matrix corresponds
#   to a locus cluster. The matrices must all
#   have the same number of columns, but can differ in the number of
#   rows.  The number of columns must equal the number of vertices in g
#   and column i is taken to correspond to vertex i.
#   @param K the number of components
#
#
# A tree cluster object partitions a tree into
# disjoint, connected components.  The partition
# is specified by K-1 vertices called cut vertices.
# The components of the tree are generated by cutting some
# of the edges going OUT of each cut vertex, and these edges
# are referred to as cut edges.  In practice, the tree_cluster
# object only stores the cut edges since those provide the cut
# vertices.  The cut edges break the tree into connected components,
# and each cluster is formed from the connected components
# pointed to by the cut edges emanating from a given cut vertex.
#
# Fitting is done by coordinate descent.  A cut vertex is chosen
# and is iteratively replaced by all other vertices in the tree to
# see if there is a better choice.  Similarly,
# for each cut vertex, all possible cut edge combinations associated
# with that cut vertex are considered.
class tree_cluster:

    def __init__(self, K, g, m_list):
        self.K = K
        self.m = m_list
        self.g = g

        # check consistency of g and m_list
        ncols = list(set([m.shape[1] for m in m_list]))
        if not len(ncols) == 1:
            sys.exit("matrices in m_list must have same number of columns")
        if not ncols[0] == g.vcount():
            sys.exit("number of vertices does not match number of columns")

        self.info = self.form_tree_information(g)
        # only allow 0,1,2 cut edge groups from a single vertex, more
        # is to computationaly intensive
        self.max_num_cut_edge_groups_from_vertex = 2

        # checks on K
        if K < 2 or K > len(self.info["internal"]):
            sys.exit("K out of range")

        # a list of edge lists.  cutting these edges forms the components
        self.cut_edges = None
        # assignment of each vertex in g to a component
        self.assignments = None
        # mediod of each component
        self.mediods = None


    @staticmethod
    def form_tree_information(g):
        """Collect the structural facts about the tree g used by the class.

        Returns a dict with the keys
          'root'             the single vertex without incoming edges
          'leaves'           vertices without outgoing edges
          'internal'         all remaining vertices (possible cut vertices)
          'internal_weights' a pmf over 'internal', proportional to the
                             size of each vertex's 'descendants' entry
          'children'         per vertex, its out-neighbors
          'descendants'      per vertex, the result of g.neighborhood
                             in "OUT" mode with a very large order

        Terminates through sys.exit unless exactly one vertex has no
        incoming edge.
        """
        n_vertices = g.vcount()
        vertex_ids = range(n_vertices)

        # the root is the only vertex nobody points to
        root_mask = np.array([len(g.neighbors(i, mode=ig.IN)) == 0
                              for i in vertex_ids])
        if np.sum(root_mask) != 1:
            sys.exit("there is not one root")
        root = np.where(root_mask)[0][0]

        children = [g.neighbors(v, mode=ig.OUT) for v in vertex_ids]
        descendants = [g.neighborhood(v, order=int(1E6), mode="OUT")
                       for v in vertex_ids]

        # split the vertices into leaves and internal vertices; the
        # internal ones are the possible cut vertices
        leaves = []
        internal_vertices = []
        for v in vertex_ids:
            if len(children[v]) == 0:
                leaves.append(v)
            else:
                internal_vertices.append(v)

        # pmf on the internal vertices so they can be sampled quickly
        # during initialization
        subtree_sizes = np.array([len(descendants[v])
                                  for v in internal_vertices])
        internal_weights = subtree_sizes/np.sum(subtree_sizes)

        return {'root': root,
                'leaves': leaves,
                'internal': internal_vertices,
                'internal_weights': internal_weights,
                'children': children,
                'descendants': descendants}


    # pick random cut vertices and include all outgoing edges
    # as the cut edges
    def initialize_components(self,initial_cutedges = [],
                              assignments = []):
        g = self.g
        if len(assignments)>0:
            if not self.K == len(set(assignments)):
                #pdb.set_trace()
                sys.exit("number of components does not equal number of clusters")

            self.cut_edges = self.assignments2cutedges(assignments)
            self.update_components()

        else:
            if(len(initial_cutedges)>0):
                cut_edges = initial_cutedges
            else:
                cut_vertices = np.random.choice(self.info['internal'],
                                                    self.K-1,
                                                    replace=False)
                # all children will form cut edges
                cut_edges = [[(v, u) for u in g.neighbors(v, mode=ig.OUT)] for v in cut_vertices]
            self.cut_edges = cut_edges
            self.update_components()

    # cut vertices are the parent vertices of the cut edges
    def get_cut_vertices(self):
        """Return the cut vertex of every cut edge group, in group order.

        The cut vertex of a group is read off as the source of the
        group's first edge.
        """
        vertices = []
        for group in self.cut_edges:
            first_edge = group[0]
            vertices.append(first_edge[0])
        return vertices

    def get_assignments(self):
        """Return the stored component label of every vertex.

        This is None until update_components() has run.
        """
        return self.assignments

    def get_mediods(self):
        """Return the stored mediods, one row per matrix and one column
        per component.

        This is None until update_mediods() has run.
        """
        return self.mediods

    # a utility method to convert assignments to cut edges
    def assignments2cutedges(self, assignments):
      cut_edges = []
      for cluster in range(1,np.max(assignments)+1):
         in_cluster = [vertex for vertex, a in enumerate(assignments) if a==cluster]
         cut_edges_cluster = [(u,v) for v in in_cluster for u in self.g.predecessors(v) if u not in in_cluster]
         cut_edges.append(cut_edges_cluster)

      return cut_edges

    def update_components(self):
        """Recompute the vertex assignments (and mediods) from the cut edges.

        Works on a copy of the tree: for every cut edge group a linker
        vertex is added and joined to the targets of the group's edges,
        and the cut edges themselves are deleted.  The weakly connected
        components of the result, restricted to the original vertices,
        become the new assignments.  Terminates through sys.exit if the
        number of distinct labels differs from K.
        """
        n_vertices = self.g.vcount()
        scratch = self.g.copy()
        for group in self.cut_edges:
            # add a linker node so that all subtrees cut off by this
            # group end up in one component
            # NOTE(review): add_vertex appears to take the new vertex's
            # name as its first argument, so this names the linker 1 —
            # harmless on the scratch copy, but confirm the intent
            scratch.add_vertex(1)
            linker = scratch.vcount() - 1
            scratch.add_edges([(linker, edge[1]) for edge in group])
            scratch.delete_edges(group)

        # component labels of the original vertices only (linkers dropped)
        membership = np.array(scratch.components(mode=ig.WEAK).membership)
        assignments = membership[:n_vertices].copy()

        # check for a situation that should never occur
        if len(set(assignments)) != self.K:
            sys.exit("number of components does not equal number of clusters")

        self.assignments = assignments
        self.update_mediods()


    def update_mediods(self):
        K = self.K
        m = self.m
        r = range(len(m))
        a = self.assignments

        cluster_means = [[np.mean(m[i][:,a==k]) for k in range(K)] for i in r]
        self.mediods = np.array(cluster_means)

    def compute_residual2(self):
        K = self.K
        mediods = self.mediods
        a = self.assignments
        m = self.m

        lm = len(m)
        col_ind = [np.where(a == k)[0] for k in range(self.K)]
        ss2 = 0

        for j in range(lm):
          for k in range(K):
            res2 = (m[j][:,col_ind[k]] - mediods[j,k])**2
            ss2 = ss2 + np.sum(res2)

        return ss2

    # modify the cut edges
    def execute_modification(self, mod, update=True):
        """Apply one modification to the cut edge groups, in place.

        mod is a dict whose "operation" entry selects the action:
          "add"       append mod["edge"] to group mod["edge group"]
          "remove"    remove mod["edge"] from group mod["edge group"]
          "transfer"  move mod["edge"] from group mod["start edge group"]
                      to group mod["end edge group"]
        Any other operation terminates the program through sys.exit.
        If update is true, update_components() is called afterwards.
        """
        operation = mod["operation"]
        groups = self.cut_edges

        if operation == "transfer":
            groups[mod["start edge group"]].remove(mod["edge"])
            groups[mod["end edge group"]].append(mod["edge"])
        elif operation == "remove":
            groups[mod["edge group"]].remove(mod["edge"])
        elif operation == "add":
            groups[mod["edge group"]].append(mod["edge"])
        else:
            sys.exit("this should never happen!")

        if update:
            self.update_components()


    # Try to reorganize the cut edge groups in vertex v, with the
    # the total number of cut groups for v held fixed.
    def modify_vertex_edge_groups(self, v):

        # need more than one child
        if len(self.info["children"][v]) < 2:
            return False

        # which cut edge groups are associated with v
        v_cut_edge_ind = np.where(np.array(self.get_cut_vertices())==v)[0]
        n_cut_groups = len(v_cut_edge_ind)

        # need at least one cut group
        # REALLY, user should only pass cut vertices into this method!
        if n_cut_groups == 0:
            return False

        v_cut_edges = [self.cut_edges[i] for i in v_cut_edge_ind]
        children = self.info["children"][v]

        cut_children = [x[1] for e in v_cut_edges for x in e]
        cut_children_group = [v_cut_edge_ind[i] for i,e in enumerate(v_cut_edges) for x in e]
        non_cut_children = list(set(children).difference(cut_children))

        # gather all the modificaitons we want to make
        modifications_grow = []
        # move non cut edges into any one of the edge groups
        for child in non_cut_children:
            for index in v_cut_edge_ind:
                modifications_grow.append({'edge':(v,child),
                                      'edge group':index,
                                      'operation':"add"})

        modifications_prune = []
        # move cut edges out of edge groups
        for i,child in enumerate(cut_children):
            # only remove a child if the group has more than 1 edge
            if len(self.cut_edges[cut_children_group[i]]) > 1:
              modifications_prune.append({'edge':(v,child),
                                        'edge group':cut_children_group[i],
                                        'operation':"remove"})

        modifications_transfer = []
        # transfer cut edge between groups
        # we can only transfer edge if group has more than 1 edge
        v_cut_edge_ind_g1 = [i for i in v_cut_edge_ind if len(self.cut_edges[i]) > 1]
        for group1 in v_cut_edge_ind_g1:
          for group2 in v_cut_edge_ind:
            if group1 == group2:
              continue
            for child in [x[1] for x in self.cut_edges[group1]]:
              modifications_transfer.append({'edge':(v,child),
                                             'start edge group':group1,
                                             'end edge group': group2,
                                             'operation':"transfer"})


        modifications = modifications_grow + modifications_prune \
                        + modifications_transfer

        # modifications will be empty if all the children
        # are in cut groups that are of size 1
        if len(modifications) == 0:
            return False


        # compute loss for all the modifications
        current_loss = self.compute_residual2()
        save_cut_edges = copy.deepcopy(self.cut_edges)
        save_a = copy.deepcopy(self.assignments)
        save_m = copy.deepcopy(self.mediods)

        new_loss = np.zeros(len(modifications))
        for i,mod in enumerate(modifications):
            self.execute_modification(mod, update=True)
            new_loss[i] = self.compute_residual2()

            # update components without all the computations
            self.cut_edges = copy.deepcopy(save_cut_edges)
            self.assignments = copy.deepcopy(save_a)
            self.mediods = copy.deepcopy(save_m)

        # update if a modification improves the loss
        best_mod = np.argmin(new_loss)
        if new_loss[best_mod] < current_loss:
            mod = modifications[best_mod]
            self.execute_modification(mod, update=True)

            return True
        else:
            return False

    # Delete an edge group and try to add it to vertex v
    def remove_and_add_edge_group(self, index, v):
        """Try to relocate cut edge group `index` to vertex v.

        The group stored at self.cut_edges[index] is overwritten by a
        new single-edge group leaving v, the groups of v are then
        refined with modify_vertex_edge_groups until no further
        improvement is found, and the result is kept only if the loss
        is strictly lower than before.

        index: position in self.cut_edges of the group to relocate.
               It must not be one of v's own groups; in that case the
               program terminates through sys.exit.
        v:     the vertex that should receive the group.

        Returns True if the relocation was kept.  Returns False if v
        cannot take another group, or if the loss did not improve, in
        which case the saved cut edges are restored and the components
        recomputed.
        """

        # groups currently owned by v; relocating one of them onto v
        # itself makes no sense
        v_cut_edge_ind = np.where(np.array(self.get_cut_vertices())==v)[0]
        if np.any(np.array(v_cut_edge_ind) == index):
            sys.exit(["you are calling remove_and_add_edge_group with",
                      "a vertex that contains the edge group to be",
                      "deleted.  Don't do this!",
                      "Call modify_vertex_edge_groups instead"])

        # we can't add an edge group to v if all children are
        # in edge groups AND all edge groups have only one child
        v_cut_edges = [self.cut_edges[i] for i in v_cut_edge_ind]
        v_cut_edge_sizes = [len(e) for e in v_cut_edges]

        children = self.info["children"][v]
        cut_children = [x[1] for e in v_cut_edges for x in e]
        # NOTE(review): cut_children_group is not read anywhere below
        cut_children_group = [v_cut_edge_ind[i] for i,e in enumerate(v_cut_edges) for x in e]
        non_cut_children = list(set(children).difference(cut_children))

        if (len(non_cut_children) == 0) and \
            (np.all(np.array(v_cut_edge_sizes) == 1)):
            return False

        # children (and their group indices) that sit in groups of v
        # with more than one edge; only these can be split off
        cut_children_g1 = [x[1] for e in v_cut_edges for x in e if len(e) > 1]
        cut_children_g1_group = [v_cut_edge_ind[i] for i,e in enumerate(v_cut_edges) \
                                 for x in e if len(e) > 1]
        # snapshot used to undo the relocation if it does not pay off
        save_cut_edges = copy.deepcopy(self.cut_edges)
        save_loss = self.compute_residual2()

        # The group is replaced in place at position `index`, so the
        # indices of all other groups stay valid.
        if len(non_cut_children) > 0:
            # replace with new group using a randomly chosen non-cut child
            new_edge_group = [(v,np.random.choice(non_cut_children, size=1)[0])]
            self.cut_edges[index] = new_edge_group
            self.update_components()
        else:
            # add new group by removing a randomly chosen edge from one
            # of v's larger groups and then adding it as a new cut group
            i = np.random.choice(range(len(cut_children_g1)), size=1)[0]
            cut_edge_group = cut_children_g1_group[i]
            mod = {'edge':(v,cut_children_g1[i]),
                   'edge group':cut_edge_group,
                   'operation':"remove"}
            # components are recomputed once, after the new group exists
            self.execute_modification(mod, update=False)
            self.cut_edges[index] = [(v, cut_children_g1[i])]
            self.update_components()

        # refine the groups of v until no modification improves the loss
        while self.modify_vertex_edge_groups(v):
          pass

        if self.compute_residual2() < save_loss:
            return True
        else:
            # no improvement: undo everything
            self.cut_edges = save_cut_edges
            self.update_components()
            return False

    # coordinate descent on the cut edges

    # given a cut edge group, compute the lowest loss if we
    # move the cut edge to vertex v, over all vertices, and
    # move the cut edge to the first vertex with lowest loss
    def optimize_cut_edge(self, index=None):
        """Try to move one cut edge group to a better vertex.

        index selects the group in self.cut_edges; when None, a group
        is drawn at random.  All internal vertices other than the
        group's current cut vertex are visited in random order, and the
        group is moved to the first vertex for which
        remove_and_add_edge_group reports an improvement.  After a
        successful move the groups remaining at the old vertex are
        given one modify_vertex_edge_groups pass.

        Returns True if the group was moved, False otherwise.
        """
        if index is None:
            index = np.random.choice(range(len(self.cut_edges)), size=1)[0]

        current_v = self.cut_edges[index][0][0]

        # candidate hosts: every internal vertex except the current one
        # (the stored list of internal vertices is left untouched)
        candidates = [v for v in self.info["internal"] if v != current_v]

        for v in np.random.permutation(candidates):
            if self.remove_and_add_edge_group(index, v):
                # the old vertex lost a group; let it regroup its edges
                self.modify_vertex_edge_groups(current_v)
                return True

        return False


    # single starting point optimization
    def optimize(self, assignments = None):
        """Run coordinate descent from a single starting point.

        If assignments is given, the components are first initialized
        from it; otherwise the clustering must already be initialized
        (the program terminates through sys.exit if there are no cut
        edges).  Each epoch calls optimize_cut_edge once for every cut
        edge group index 0..K-2, in random order.  Epochs repeat until
        the loss no longer decreases.  An increase of the loss
        terminates the program through sys.exit.  Progress and the
        elapsed time are printed.
        """
        start = time.time()

        print("beginning a single tree cluster optimization")
        if assignments is not None:
            self.initialize_components(assignments = assignments)

        if self.cut_edges is None:
            sys.exit("assignments must be passed or clustering initialized.")

        previous_loss = self.compute_residual2()
        iteration = 1
        converged = False
        while not converged:
            print(["epoch", iteration, ",current loss:", previous_loss])
            for index in np.random.permutation(self.K-1):
                self.optimize_cut_edge(index)

            current_loss = self.compute_residual2()
            if current_loss > previous_loss:
                # accepted moves only ever lower the loss
                sys.exit("unexpected state!")
            if current_loss < previous_loss:
                previous_loss = current_loss
                iteration = iteration + 1
            else:
                converged = True

        print(["optimization time", time.time() - start])


    def treeplot(self,savepath='',
                 vertex_label=False,
                 m_index=None,
                 target=None):
        """Plot the tree with the vertices colored by component.

        savepath:     when non-empty, passed to ig.plot as the second
                      argument; otherwise the plot is shown.
        vertex_label: if true, label vertices with g.vs['name'];
                      otherwise with "<vertex index>-<component>".
        m_index:      if given, vertex sizes are derived from the
                      column means of self.m[m_index]; otherwise a
                      fixed size is used.
        target:       only used when savepath is empty; passed on to
                      ig.plot.
                      NOTE(review): in this case the visual style
                      (colors, sizes, labels) is not forwarded —
                      confirm that this is intended.

        Prints a message and returns None if the tree has not been
        initialized.
        """
        a = self.assignments
        g = self.g

        if a is None:
            print("tree has not been initialized.",
                  "There is nothing to plot.")
            return None

        pal = ig.drawing.colors.ClusterColoringPalette(self.K)
        colors = pal.get_many(a)

        if m_index is None:
            sizes = 20
        else:
            sizes = 30*np.mean(self.m[m_index], 0) + 1

        if vertex_label:
            labels = g.vs['name']
        else:
            labels = [str(i) + "-" + str(a[i]) for i in range(g.vcount())]

        vs = {
            "vertex_color": colors,
            "bbox": (1200, 1000),
            "vertex_size": sizes,
            "vertex_label_size": 20,
            "vertex_label": labels,
            "vertex_label_dist": 1.5,
        }

        layout = g.layout_reingold_tilford(mode="all")

        if savepath != '':
            ig.plot(g,savepath,layout=layout, **vs)
        elif target is None:
            pl = ig.plot(g, layout=layout, **vs)
            pl.show()
        else:
            pl = ig.plot(g, layout=layout, target=target)
            pl.show()
