import os
import pickle
import functools

from autorecsys.searcher.core.trial import Stateful
from autorecsys.searcher.core import hyperparameters as hp_module
from autorecsys.pipeline import base

import tensorflow as tf
from tensorflow.python.util import nest


class Graph(Stateful):
    """A graph consists of connected Blocks and HyperBlocks.

    On construction the graph is walked from the input nodes towards the
    output nodes. Every node from which an output can be reached is kept, and
    the blocks between the kept nodes are stored in topological order.

    # Arguments
        inputs: A list of input node(s) for the Graph.
        outputs: A list of output node(s) for the Graph.
    """
    def __init__(self, inputs, outputs):
        super().__init__()
        # Accept a single node as well as any nested structure of nodes.
        self.inputs = nest.flatten(inputs)
        self.outputs = nest.flatten(outputs)
        # Node -> integer id. A node receives its id only after the nodes
        # downstream of it were searched (see _search_network).
        self._node_to_id = {}
        # The nodes ordered by id: reverse order of the topological sort.
        self._nodes = []
        # Topological sort of the blocks in the graph, and block -> position.
        self._blocks = []
        self._block_to_id = {}
        self._build_network()

    def compile(self, func):
        """Share the information between blocks by calling functions in compiler.

        # Arguments
            func: A dictionary. The keys are the block classes. The values are
                corresponding compile functions.
        """
        for block in self._blocks:
            # Exact class match: subclasses of a key are not dispatched.
            if block.__class__ in func:
                func[block.__class__](block)

    def _build_network(self):
        """Discover the nodes and blocks that connect inputs to outputs.

        Rebuilds `_node_to_id`, `_nodes`, `_blocks` and `_block_to_id` from
        scratch.

        # Raises
            ValueError: If an input or output node is not part of the
                discovered network, if a selected block reads a node that is
                not part of it, if the search finds a cycle, or if the blocks
                cannot be ordered topologically.
        """
        self._node_to_id = {}

        # Recursively find all the nodes from which an output is reachable.
        for input_node in self.inputs:
            self._search_network(input_node, self.outputs, set(), set())
        # The topological sort of the graph in reverse order.
        self._nodes = sorted(self._node_to_id, key=self._node_to_id.get)

        for node in self.inputs + self.outputs:
            if node not in self._node_to_id:
                raise ValueError('Inputs and outputs not connected.')

        # Find the blocks: those leaving a kept node and producing at least
        # one kept node.
        blocks = []
        for input_node in self._nodes:
            for block in input_node.out_blocks:
                if block in blocks:
                    continue
                if any(output_node in self._node_to_id
                       for output_node in block.outputs):
                    blocks.append(block)

        # Check if all the inputs of the blocks are part of the network.
        for block in blocks:
            for input_node in block.inputs:
                if input_node not in self._node_to_id:
                    raise ValueError('A required input is missing for HyperModel '
                                     '{name}.'.format(name=block.name))

        # Calculate the in degree of all the nodes: the number of selected
        # blocks producing the node. `_nodes` is sorted by id and ids are
        # consecutive from 0, so a position in `_nodes` equals the node id.
        in_degree = [
            sum(1 for block in node.in_blocks if block in blocks)
            for node in self._nodes
        ]

        # Add the blocks in topological order.
        self._blocks = []
        self._block_to_id = {}
        while blocks:
            # Collect the blocks whose inputs all have in degree 0.
            ready = [
                block for block in blocks
                if not any(in_degree[self._node_to_id[node]]
                           for node in block.inputs)
            ]
            if not ready:
                # Without this guard a pass that schedules nothing would loop
                # forever, because `blocks` would never shrink.
                raise ValueError('Cannot sort the blocks topologically.')

            # Remove the collected blocks from blocks.
            for block in ready:
                blocks.remove(block)

            for block in ready:
                self._add_block(block)

                # Decrease the in degree of the output nodes.
                for output_node in block.outputs:
                    if output_node in self._node_to_id:
                        in_degree[self._node_to_id[output_node]] -= 1

    def _search_network(self, input_node, outputs, in_stack_nodes,
                        visited_nodes):
        """Depth-first search registering the nodes that lead to an output.

        # Arguments
            input_node: The node to search from.
            outputs: The output nodes of the graph.
            in_stack_nodes: Set of the nodes on the current recursion path.
                Used to detect cycles.
            visited_nodes: Set of the nodes already searched during this
                traversal.

        # Raises
            ValueError: If a block output is already on the recursion path.
        """
        visited_nodes.add(input_node)
        in_stack_nodes.add(input_node)

        outputs_reached = input_node in outputs

        for block in input_node.out_blocks:
            for output_node in block.outputs:
                if output_node in in_stack_nodes:
                    raise ValueError('The network has a cycle.')
                if output_node not in visited_nodes:
                    self._search_network(output_node, outputs, in_stack_nodes,
                                         visited_nodes)
                # A registered successor means an output is reachable from it,
                # and therefore from this node as well.
                if output_node in self._node_to_id:
                    outputs_reached = True

        # Registered after the successors, which yields the reverse
        # topological numbering.
        if outputs_reached:
            self._add_node(input_node)

        in_stack_nodes.remove(input_node)

    def _add_block(self, block):
        """Append `block` to the ordered block list unless already present."""
        if block not in self._blocks:
            block_id = len(self._blocks)
            self._block_to_id[block] = block_id
            self._blocks.append(block)

    def _add_node(self, input_node):
        """Give `input_node` the next free id unless it already has one."""
        if input_node not in self._node_to_id:
            self._node_to_id[input_node] = len(self._node_to_id)

    def _get_block(self, name):
        """Return the first block whose `name` equals `name`.

        # Raises
            ValueError: If no block of the graph has that name.
        """
        for block in self._blocks:
            if block.name == name:
                return block
        raise ValueError('Cannot find block named {name}.'.format(name=name))

    def get_state(self):
        """Collect the state of every block and node.

        # Returns
            A dictionary with the keys 'blocks' and 'nodes'. Each value maps
            the stringified position of a block (in `_blocks`) or node (in
            `_nodes`) to the result of its `get_state()`.
        """
        block_state = {str(block_id): block.get_state()
                       for block_id, block in enumerate(self._blocks)}
        node_state = {str(node_id): node.get_state()
                      for node_id, node in enumerate(self._nodes)}
        return {'blocks': block_state, 'nodes': node_state}

    def set_state(self, state):
        """Restore the blocks and nodes from a `get_state()` dictionary.

        States are matched by position, so `state` has to contain an entry for
        every block and node of this graph.
        """
        block_state = state['blocks']
        node_state = state['nodes']
        for block_id, block in enumerate(self._blocks):
            block.set_state(block_state[str(block_id)])
        for node_id, node in enumerate(self._nodes):
            node.set_state(node_state[str(node_id)])

    def save(self, fname):
        """Pickle the state of the graph to `fname`.

        # Returns
            `fname` converted to a string.
        """
        state = self.get_state()
        with tf.io.gfile.GFile(fname, 'wb') as f:
            pickle.dump(state, f)
        return str(fname)

    def reload(self, fname):
        """Load a state pickled by `save` from `fname` and apply it.

        NOTE(security): unpickling can execute arbitrary code. Only reload
        files that come from a trusted source.
        """
        with tf.io.gfile.GFile(fname, 'rb') as f:
            state = pickle.load(f)
        self.set_state(state)

    def build(self, hp):
        """Do nothing. Subclasses override this to build from `hp`."""
        pass

class PreprocessGraph(Graph):
    """A graph that consists of only Preprocessors.
    It is both a search space with Hyperparameters and a model to be fitted. It
    preprocesses the dataset with the Preprocessors. The output is the input to
    the Keras model. It does not extend the Hypermodel class because it cannot be
    built into a Keras model.
    """

    def preprocess(self, dataset, validation_data=None, fit=False):
        """Preprocess the data to be ready for the Keras Model.
        # Arguments
            dataset: tf.data.Dataset. Training data.
            validation_data: tf.data.Dataset. Validation data.
            fit: Boolean. Whether to fit the preprocessors on `dataset` before
                transforming it. The validation data is never used for fitting.
        # Returns
            A tuple (train, validation). `train` is the preprocessed training
            dataset. `validation` is the preprocessed validation dataset, or
            the `validation_data` argument unchanged (None by default) when it
            is falsy.
        """
        dataset = self._preprocess(dataset, fit=fit)
        if validation_data:
            validation_data = self._preprocess(validation_data)
        return dataset, validation_data

    def _preprocess(self, dataset, fit=False):
        """Run the dataset through the preprocessors, one depth at a time.

        # Arguments
            dataset: The dataset to transform.
            fit: Boolean. Whether to fit the blocks of each depth before the
                dataset is transformed by them.
        # Returns
            The dataset after the last `map` call.
        """
        # A list of input node ids in the same order as the x in the dataset.
        input_node_ids = [self._node_to_id[input_node] for input_node in self.inputs]
        # The ids that have to be available in the end. They do not change
        # while iterating, so they are computed once.
        target_node_ids = {self._node_to_id[node] for node in self.outputs}

        # Iterate until all the model inputs have their data.
        while target_node_ids - set(input_node_ids):
            # Gather the blocks for the next iteration over the dataset.
            blocks = [
                block
                for node_id in input_node_ids
                for block in self._nodes[node_id].out_blocks
                if block in self._blocks
            ]
            if fit:
                # Iterate the dataset to fit the preprocessors in current depth.
                self._fit(dataset, input_node_ids, blocks)

            # Transform the dataset. `_transform` appends the ids of the
            # columns it produces to `output_node_ids` when it is invoked
            # through `dataset.map`.
            output_node_ids = []
            dataset = dataset.map(functools.partial(
                self._transform,
                input_node_ids=input_node_ids,
                output_node_ids=output_node_ids,
                blocks=blocks,
                fit=fit))

            # Build input_node_ids for next depth.
            input_node_ids = output_node_ids
        return dataset

    def _fit(self, dataset, input_node_ids, blocks):
        """Fit `blocks` on the dataset and set the shapes of their outputs.

        # Arguments
            dataset: Iterable of (x, y) pairs.
            input_node_ids: The node ids matching the flattened x, in order.
            blocks: The blocks to be fitted.
        """
        # Iterate the dataset to fit the preprocessors in current depth.
        for x, y in dataset:
            id_to_data = dict(zip(input_node_ids, nest.flatten(x)))
            for block in blocks:
                data = [id_to_data[self._node_to_id[input_node]]
                        for input_node in block.inputs]
                block.update(data, y=y)

        # Finalize and set the shapes of the output nodes.
        for block in blocks:
            block.finalize()
            nest.flatten(block.outputs)[0].shape = block.output_shape

    def _transform(self,
                   x,
                   y,
                   input_node_ids,
                   output_node_ids,
                   blocks,
                   fit=False):
        """Transform one (x, y) element with the blocks of the current depth.

        # Arguments
            x: The features. Flattened and matched to `input_node_ids` by
                position.
            y: The targets. Returned unchanged.
            input_node_ids: The node ids matching the flattened x, in order.
            output_node_ids: A list that is extended in place with the sorted
                ids of the produced data.
            blocks: The blocks to apply.
            fit: Boolean. Forwarded to the `transform` function of each block.
        # Returns
            A tuple (outputs, y) where `outputs` is a tuple of the data ordered
            as in `output_node_ids`.
        """
        id_to_data = dict(zip(input_node_ids, nest.flatten(x)))
        output_data = {}
        # Transform each x by the corresponding block.
        for hm in blocks:
            data = [id_to_data[self._node_to_id[input_node]]
                    for input_node in hm.inputs]
            data = tf.py_function(functools.partial(hm.transform, fit=fit),
                                  inp=nest.flatten(data),
                                  Tout=hm.output_types())
            # Only the first output is kept, with the shape declared by the
            # block.
            data = nest.flatten(data)[0]
            data.set_shape(hm.output_shape)
            output_data[self._node_to_id[hm.outputs[0]]] = data
        # Keep the Keras Model inputs even if they are not inputs to the
        # blocks.
        for node_id, data in id_to_data.items():
            if self._nodes[node_id] in self.outputs:
                output_data[node_id] = data

        output_node_ids.extend(sorted(output_data))
        return tuple(output_data[node_id] for node_id in output_node_ids), y

    def build(self, hp):
        """Obtain the values of all the HyperParameters.
        Different from the build function of Hypermodel. This build function does not
        produce a Keras model. It only obtains the hyperparameter values from
        HyperParameters.
        # Arguments
            hp: HyperParameters.
        """
        super().build(hp)
        # self.compile(compiler.BEFORE)
        for block in self._blocks:
            block.build(hp)


class KerasGraph(Graph, base.HyperModel):
    """A Graph that is also a HyperModel and builds into a Keras model."""

    def build(self, hp):
        """Build the HyperModel into a Keras Model.

        # Arguments
            hp: HyperParameters. Passed to the build function of every block
                and used when the model is compiled.
        # Returns
            The model returned by `_compile_keras_model`.
        """
        super().build(hp)
        # self.compile(compiler.AFTER)
        id_of = self._node_to_id
        # Node id -> the object built for that node.
        built = {}
        for graph_input in self.inputs:
            input_id = id_of[graph_input]
            built[input_id] = graph_input.build()

        # `_blocks` is topologically sorted, so the inputs of a block are
        # available by the time the block is built.
        for block in self._blocks:
            block_inputs = [built[id_of[node]] for node in block.inputs]
            block_outputs = nest.flatten(block.build(hp, inputs=block_inputs))
            for node, built_node in zip(block.outputs, block_outputs):
                built[id_of[node]] = built_node

        model_inputs = [built[id_of[node]] for node in self.inputs]
        model_outputs = [built[id_of[node]] for node in self.outputs]
        model = tf.keras.Model(model_inputs, model_outputs)

        return self._compile_keras_model(hp, model)

    def _get_metrics(self):
        """Return the `metric` of each optimizer block feeding an output.

        Only the first in-block of every output node is inspected. It counts
        as an optimizer block when 'optimizer' occurs in the string of its
        type.
        """
        head_blocks = (node.in_blocks[0] for node in self.outputs)
        return [head.metric for head in head_blocks
                if 'optimizer' in str(type(head))]

    def _get_loss(self):
        """Return the `loss` of each optimizer block feeding an output.

        Only the first in-block of every output node is inspected. It counts
        as an optimizer block when 'optimizer' occurs in the string of its
        type.
        """
        head_blocks = (node.in_blocks[0] for node in self.outputs)
        return [head.loss for head in head_blocks
                if 'optimizer' in str(type(head))]

    def _compile_keras_model(self, hp, model):
        """Compile `model` with the optimizer chosen by `hp`.

        # Returns
            The same `model`, compiled.
        """
        # Specify hyperparameters from compile(...)
        # Candidates currently disabled: 'adadelta', "Adagrad", "RMSprop",
        # "AdaMax", 'sgd'.
        candidates = ['adam']
        chosen_optimizer = hp.Choice('optimizer', candidates)

        model.compile(
            optimizer=chosen_optimizer,
            metrics=self._get_metrics(),
            loss=self._get_loss(),
        )
        return model


class PlainGraph(Graph):
    """A graph built from a HyperGraph to produce KerasGraph and PreprocessGraph.
    A PlainGraph does not contain HyperBlock. HyperGraph's hyper_build function
    returns an instance of PlainGraph, which can be directly built into a KerasGraph
    and a PreprocessGraph.
    # Arguments
        inputs: A list of input node(s) for the PlainGraph.
        outputs: A list of output node(s) for the PlainGraph.
    """

    def __init__(self, inputs, outputs, **kwargs):
        # Initialized before the base constructor, which calls
        # `_build_network` and thereby fills this list.
        self._keras_model_inputs = []
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    def _build_network(self):
        """Build the network, then find the input nodes of the Keras model."""
        super()._build_network()
        # Find the model input nodes. The list is rebuilt from scratch, like
        # the state rebuilt by the base method, so that calling this method
        # again does not accumulate duplicates.
        self._keras_model_inputs = sorted(
            (node for node in self._nodes
             if self._is_keras_model_inputs(node)),
            key=lambda x: self._node_to_id[x])

    @staticmethod
    def _is_keras_model_inputs(node):
        """Tell whether `node` sits between preprocessing and the model.

        # Returns
            True if every block producing `node` is a Preprocessor and at
            least one block consuming it is not a Preprocessor.
        """
        if not all(isinstance(block, base.Preprocessor)
                   for block in node.in_blocks):
            return False
        return any(not isinstance(block, base.Preprocessor)
                   for block in node.out_blocks)

    def build_keras_graph(self):
        """Return a KerasGraph from the Keras model inputs to the outputs."""
        return KerasGraph(self._keras_model_inputs,
                          self.outputs)

    def build_preprocess_graph(self):
        """Return a PreprocessGraph from the inputs to the Keras model inputs."""
        return PreprocessGraph(self.inputs,
                               self._keras_model_inputs)


def copy(old_instance):
    """Return a new object of the same class carrying the same state.

    The class is instantiated without arguments and the state is transferred
    with `get_state()` / `set_state()`.
    """
    duplicate = old_instance.__class__()
    snapshot = old_instance.get_state()
    duplicate.set_state(snapshot)
    return duplicate



class HyperGraph(Graph):
    """A HyperModel based on connected Blocks and HyperBlocks.
    # Arguments
        inputs: A list of input node(s) for the HyperGraph.
        outputs: A list of output node(s) for the HyperGraph.
    """

    def __init__(self, inputs, outputs, **kwargs):
        super().__init__(inputs, outputs, **kwargs)
        # self.compile(compiler.HYPER)

    def build_graphs(self, hp):
        """Build the PlainGraph for `hp` and return its KerasGraph."""
        plain_graph = self.hyper_build(hp)
        return plain_graph.build_keras_graph()

    def save_weights(self, directory):
        """Ask every block to save its weights to `directory`/<block name>."""
        for block in self._blocks:
            block_filename = os.path.join(directory, block.name)
            block.save_weights(filename=block_filename)

    def get_hyperparameters(self):
        """Get the tunable hyperparameters from all the blocks in this pipeline.

        # Returns
            A HyperParameters object in which the hyperparameters of every
            block are registered under a name scope named after the block.
        """
        hps = hp_module.HyperParameters()
        for block in self._blocks:
            params_dict = block.hyperparameters
            if params_dict:
                with hps.name_scope(block.name):
                    for param_name, single_hp in params_dict.items():
                        hps.register(param_name,
                                     single_hp.__class__.__name__,
                                     single_hp.get_config())
        return hps

    def hyper_build(self, hp):
        """Build a PlainGraph with no HyperBlock but only Block.

        The input nodes and the plain blocks are copied; every HyperBlock is
        built with `hp`.

        # Arguments
            hp: HyperParameters. Passed to the build function of HyperBlocks.
        # Returns
            A PlainGraph connecting the new input nodes to the new output
            nodes.
        """
        # Make sure get_uid would count from start.
        tf.keras.backend.clear_session()
        # Maps every node of this graph to its counterpart in the new graph.
        old_node_to_new = {}
        for old_input_node in self.inputs:
            old_node_to_new[old_input_node] = copy(old_input_node)

        # `_blocks` is topologically sorted, so the new inputs of a block
        # exist by the time the block is reached.
        for old_block in self._blocks:
            block_inputs = [old_node_to_new[input_node]
                            for input_node in old_block.inputs]
            if isinstance(old_block, base.HyperBlock):
                block_outputs = old_block.build(hp, inputs=block_inputs)
            else:
                block_outputs = copy(old_block)(block_inputs)
            # Flattened as in KerasGraph.build, so that a block returning a
            # single node is handled like one returning a list of nodes.
            for output_node, old_output_node in zip(nest.flatten(block_outputs),
                                                    old_block.outputs):
                old_node_to_new[old_output_node] = output_node

        inputs = [old_node_to_new[input_node] for input_node in self.inputs]
        outputs = [old_node_to_new[output_node]
                   for output_node in self.outputs]
        return PlainGraph(inputs, outputs)
