#!/usr/bin/env python
from __future__ import division, print_function
import argparse
import numpy as np
import igraph as ig
import leidenalg
import sys


#based on https://github.com/theislab/scanpy/blob/8131b05b7a8729eae3d3a5e146292f377dd736f7/scanpy/_utils.py#L159
def get_igraph(sources_idxs_file, targets_idxs_file, weights_file, n_vertices):
    """Build a weighted igraph graph from edge arrays stored on disk.

    Args:
        sources_idxs_file: path passed to ``np.load``; holds the source
            vertex index of every edge.
        targets_idxs_file: path passed to ``np.load``; holds the target
            vertex index of every edge, aligned with the sources.
        weights_file: path passed to ``np.load``; holds one weight per edge,
            aligned with the sources/targets.
        n_vertices: number of vertices to create in the graph.

    Returns:
        The constructed ``ig.Graph`` with the edge attribute ``'weight'`` set.

    Raises:
        ValueError: if the three loaded arrays do not have the same length.
    """
    sources = np.load(sources_idxs_file)
    targets = np.load(targets_idxs_file)
    weights = np.load(weights_file)

    # zip() below stops at the shorter of sources/targets, so a length
    # mismatch would silently drop edges; fail loudly instead.
    if not (len(sources) == len(targets) == len(weights)):
        raise ValueError(
            "sources, targets and weights must have the same length "
            "(got {}, {} and {})".format(
                len(sources), len(targets), len(weights)))

    # NOTE(review): directed=None is kept from the original; presumably
    # igraph treats it as falsy (undirected graph) -- confirm.
    g = ig.Graph(directed=None)
    g.add_vertices(n_vertices)  # one vertex per row of the adjacency matrix
    g.add_edges(list(zip(sources.tolist(), targets.tolist())))
    g.es['weight'] = weights.tolist()
    # Sanity check inherited from the scanpy helper this is based on.
    if g.vcount() != n_vertices:
        print('WARNING: The constructed graph has only '
              + str(g.vcount()) + ' nodes. '
              'Your adjacency matrix contained redundant nodes.')
    return g


if __name__ == "__main__":
    # Command-line driver: load the graph, run Leiden (or a constrained
    # refinement of a given membership), and print the partition quality
    # followed by one membership id per line on stdout.
    parser = argparse.ArgumentParser()
    parser.add_argument("--sources_idxs_file", required=True)
    parser.add_argument("--targets_idxs_file", required=True)
    parser.add_argument("--weights_file", required=True)
    parser.add_argument("--n_vertices", type=int, required=True)
    parser.add_argument("--partition_type", required=True)
    parser.add_argument("--n_iterations", type=int, required=True)
    parser.add_argument("--initial_membership_file", default=None,
                        required=False)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--refine", action="store_true")

    args = parser.parse_args()

    the_graph = get_igraph(
        sources_idxs_file=args.sources_idxs_file,
        targets_idxs_file=args.targets_idxs_file,
        weights_file=args.weights_file,
        n_vertices=args.n_vertices)

    # Resolve the partition class by attribute lookup rather than eval(), so
    # the command-line string is never executed as code. Dotted names (e.g.
    # "VertexPartition.ModularityVertexPartition") are still accepted, as they
    # were with eval("leidenalg." + name); an unknown name raises
    # AttributeError.
    partition_type = leidenalg
    for attr_name in args.partition_type.split("."):
        partition_type = getattr(partition_type, attr_name)

    n_iterations = args.n_iterations
    initial_membership = (None if args.initial_membership_file is None
                          else np.load(args.initial_membership_file).tolist())
    seed = args.seed

    # Pass the edge weights on as a plain list of float64 values.
    weights = (np.array(the_graph.es['weight']).astype(np.float64)).tolist()

    if not args.refine:
        partition = leidenalg.find_partition(
            graph=the_graph,
            partition_type=partition_type,
            weights=weights,
            n_iterations=n_iterations,
            initial_membership=initial_membership,
            seed=seed)
    else:
        # Refine the partition suggested by initial_membership.
        # Note: n_iterations is not used in this branch.
        # Code here is based on a combination of the find_partition code at
        # https://github.com/vtraag/leidenalg/blob/9ffd92ada566d7cce094afd4ec5c70209609af26/src/functions.py#L26
        # and the discussion of refining communities in
        # https://leidenalg.readthedocs.io/en/stable/advanced.html#optimiser
        constraining_partition = partition_type(
            graph=the_graph, initial_membership=initial_membership,
            weights=weights)
        # If initial_membership is not specified, each node is placed in its
        # own cluster, as per
        # https://leidenalg.readthedocs.io/en/stable/reference.html#mutablevertexpartition
        refined_partition_mergenodes = partition_type(
            graph=the_graph, weights=weights)

        # Background discussion: https://github.com/vtraag/leidenalg/issues/61
        # move_nodes is what the original Louvain uses (as per
        # https://leidenalg.readthedocs.io/en/stable/advanced.html#optimiser);
        # merge_nodes is what the refinement step of Leiden uses.
        # Only merge_nodes_constrained is run here: move_nodes_constrained
        # is skipped because of a segfault bug, see
        # https://github.com/vtraag/leidenalg/issues/68
        # (the earlier idea was to run both and keep the higher-quality one).
        optimiser = leidenalg.Optimiser()
        optimiser.set_rng_seed(seed)
        optimiser.merge_nodes_constrained(refined_partition_mergenodes,
                                          constraining_partition)
        partition = refined_partition_mergenodes

    quality = partition.quality()
    # The marker line separates any earlier stdout output (e.g. warnings)
    # from the results that follow.
    print("########################")
    print("Quality:", quality)
    print("Membership:")
    print("\n".join(str(x) for x in partition.membership))
    sys.stdout.flush()
