# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from azureml.data._dataset import AbstractDataset
from azureml.data.data_reference import DataReference
from azure.ml.component import Component, PipelineComponent

from .component import _OutputBuilder, _InputBuilder
from ._pipeline_parameters import PipelineParameter

from ._restclients.designer.models import ComputeSetting, DatastoreSetting,\
    SubPipelinesInfo, SubPipelineDefinition, Kwarg, SubPipelineParameterAssignment,\
    Parameter, SubGraphParameterAssignment, DataPathParameter, SubGraphDataPathParameterAssignment,\
    SubGraphConnectionInfo, SubGraphPortInfo, SubGraphInfo

from ._utils import _get_data_info_hash_id


class SubPipelinesInfoBuilder(object):
    """Helper to build sub pipelines info from a pipeline component.

    The built ``SubPipelinesInfo`` contains the distinct sub pipeline
    definitions, one ``SubGraphInfo`` per pipeline returned by
    ``pipeline._expand_pipeline_to_pipelines()``, and a mapping from each
    module node's graph node id to the ``_id`` of ``node.pipeline``.
    """
    def __init__(self, pipeline, module_node_to_graph_node_mapping, pipeline_parameters, support_local_dataset=False):
        """Collect everything needed to build the sub pipelines info.

        :param pipeline: The pipeline component to describe.
        :param module_node_to_graph_node_mapping: Maps a module node's
            ``_get_instance_id()`` to its graph node id.
        :param pipeline_parameters: Pipeline parameters, forwarded to
            ``pipeline._build_input_to_data_info_dict``.
        :param support_local_dataset: Forwarded to
            ``pipeline._build_input_to_data_info_dict``.
        """
        self.pipeline = pipeline
        self.module_nodes, _ = pipeline._expand_pipeline_nodes()
        self.pipelines = pipeline._expand_pipeline_to_pipelines()
        self.input_to_data_info_dict = pipeline._build_input_to_data_info_dict(self.module_nodes,
                                                                               pipeline_parameters,
                                                                               support_local_dataset)
        self.module_node_to_graph_node_mapping = module_node_to_graph_node_mapping
        self.pipeline_parameters = pipeline_parameters

    def build(self):
        """Build the ``SubPipelinesInfo`` for all expanded pipelines.

        :return: Info holding the distinct pipeline definitions, one sub graph
            info per expanded pipeline, and the graph-node-id -> sub-graph-id mapping.
        :rtype: SubPipelinesInfo
        """
        # Distinct pipeline definitions, de-duplicated by equality, first-seen order kept.
        definitions = []
        for p in self.pipelines:
            if p._pipeline_definition not in definitions:
                definitions.append(p._pipeline_definition)

        # graph node id of each module node -> _id of the pipeline it belongs to
        node_id_2_graph_id = {}
        for node in self.module_nodes:
            step_id = self.module_node_to_graph_node_mapping[node._get_instance_id()]
            node_id_2_graph_id[step_id] = node.pipeline._id

        sub_graph_infos = [self._get_sub_graph_info(p) for p in self.pipelines]

        return SubPipelinesInfo(sub_graph_info=sub_graph_infos,
                                node_id_to_sub_graph_id_mapping=node_id_2_graph_id,
                                sub_pipeline_definition=definitions)

    def _get_sub_graph_info(self, pipeline):
        """Build the ``SubGraphInfo`` describing one (sub) pipeline.

        :param pipeline: One of the pipeline components in ``self.pipelines``.
        :return: Sub graph info with default settings, parameter assignments and ports.
        :rtype: SubGraphInfo
        """
        def _is_default_compute_node(node):
            # A module (non-pipeline) node whose runsettings target is unset.
            if isinstance(node, Component) and not isinstance(node, PipelineComponent):
                return node.runsettings.target is None
            return False

        def _is_default_datastore_node(node, default_datastore):
            # A module (non-pipeline) node none of whose outputs uses another datastore.
            if isinstance(node, Component) and not isinstance(node, PipelineComponent):
                return not any(output.datastore != default_datastore for output in node.outputs.values())
            return False

        compute = _get_compute_setting(pipeline._get_default_compute_target())
        # NOTE(review): the setting is built from ``pipeline.default_datastore`` while the node
        # classification below compares against ``pipeline._default_datastore``; confirm intended.
        datastore = _get_data_store_setting(pipeline.default_datastore)

        sub_graph_default_compute_target_nodes = [_get_graph_node_id(node, self.module_node_to_graph_node_mapping)
                                                  for node in pipeline.nodes
                                                  if _is_default_compute_node(node)]

        sub_graph_default_data_store_nodes = [_get_graph_node_id(node, self.module_node_to_graph_node_mapping)
                                              for node in pipeline.nodes
                                              if _is_default_datastore_node(node, pipeline._default_datastore)]

        sub_graph_parameter_assignment = self._get_parameter_assignments(pipeline)
        sub_graph_data_path_parameter_assignment = self._get_data_parameter_assignments(pipeline)

        sub_graph_input_ports = self._build_sub_graph_input_ports(pipeline)
        sub_graph_output_ports = self._build_sub_graph_output_ports(pipeline)

        return SubGraphInfo(
            name=pipeline.name, description=pipeline.description,
            default_compute_target=compute, default_data_store=datastore,
            id=pipeline._id, parent_graph_id=pipeline._parent._id if pipeline._parent else None,
            pipeline_definition_id=pipeline._pipeline_definition.id,
            sub_graph_parameter_assignment=sub_graph_parameter_assignment,
            sub_graph_data_path_parameter_assignment=sub_graph_data_path_parameter_assignment,
            sub_graph_default_compute_target_nodes=sub_graph_default_compute_target_nodes,
            sub_graph_default_data_store_nodes=sub_graph_default_data_store_nodes,
            inputs=sub_graph_input_ports,
            outputs=sub_graph_output_ports)

    def _get_parameter_assignments(self, pipeline):
        """Build the parameter assignments for the parameters of ``pipeline``.

        For every name in ``pipeline._parameters_param``, collect the
        (child node id, input/parameter name) pairs of the child nodes that
        consume it. Parameters with no consumer are left out of the result.

        :param pipeline: The pipeline whose parameters are assigned.
        :return: list of ``SubGraphParameterAssignment``.
        """
        # parent parameter name -> list of SubPipelineParameterAssignment
        parameter_assignments_mapping = {k: [] for k in pipeline._parameters_param}

        def find_matched_parent_input(input_value):
            is_sub_pipeline = pipeline._is_sub_pipeline
            # case1: input from root pipeline or from sub pipeline's pipeline parameter with None value
            # case2: input from sub pipeline's input
            # refer to _build_pipeline_parameter() in pipeline.py
            return any(input_value.dset == input_v.dset or (is_sub_pipeline and input_value.dset == input_v)
                       for input_v in pipeline.inputs.values())

        def wraps_pipeline_parameter(para_v):
            # True for a PipelineParameter wrapped with _InputBuilder.
            return isinstance(para_v, _InputBuilder) and \
                isinstance(para_v._get_internal_data_source(), PipelineParameter)

        def get_parent_parameter_name(node, para_v):
            # For a sub pipeline, the value is a PipelineParameter wrapped with _InputBuilder.
            if isinstance(node, PipelineComponent):
                return para_v.dset.name if wraps_pipeline_parameter(para_v) else None
            # For a module node, the value is either a wrapped or a bare PipelineParameter.
            if isinstance(node, Component) and \
                    (wraps_pipeline_parameter(para_v) or isinstance(para_v, PipelineParameter)):
                return para_v.name
            return None

        def get_param_name(dset):
            if isinstance(dset, DataReference):
                return dset.data_reference_name
            return dset.name

        def get_node_parameter_items(node):
            if isinstance(node, PipelineComponent):
                return node._parameters_param.items()
            return node._build_params().items()

        def try_to_add_assignments(assignments_mapping, parent_param_name, node_id, para_name):
            # Only parameters declared on this pipeline are tracked; anything else is ignored.
            if parent_param_name in assignments_mapping:
                assignments_mapping[parent_param_name].append(
                    SubPipelineParameterAssignment(node_id=node_id,
                                                   parameter_name=para_name))

        for node in pipeline.nodes:
            node_id = _get_graph_node_id(node, self.module_node_to_graph_node_mapping)
            for input_name, input_value in node.inputs.items():
                if find_matched_parent_input(input_value):
                    try_to_add_assignments(assignments_mapping=parameter_assignments_mapping,
                                           parent_param_name=get_param_name(input_value.dset),
                                           node_id=node_id,
                                           para_name=input_name)

            for para_k, para_v in get_node_parameter_items(node):
                try_to_add_assignments(assignments_mapping=parameter_assignments_mapping,
                                       parent_param_name=get_parent_parameter_name(node, para_v),
                                       node_id=node_id,
                                       para_name=para_k)

        parameter_assignments = []
        for k, v in pipeline._parameters_param.items():
            assignments = parameter_assignments_mapping[k]
            if not assignments:
                continue
            parameter = Parameter(name=k, default_value=v.default_value) \
                if isinstance(v, PipelineParameter) else Parameter(name=k)
            parameter_assignments.append(SubGraphParameterAssignment(
                parameter=parameter,
                parameter_assignments=assignments))

        return parameter_assignments

    def _get_data_parameter_assignments(self, pipeline):
        """Build data path parameter assignments for dataset inputs of ``pipeline``.

        Only inputs that are an ``_InputBuilder`` whose ``dset`` is an
        ``AbstractDataset`` produce an assignment.

        :return: list of ``SubGraphDataPathParameterAssignment``.
        """
        dataset_parameter_assignment = []

        for input_k, input_v in pipeline.inputs.items():
            if isinstance(input_v, _InputBuilder) and isinstance(input_v.dset, AbstractDataset):
                # Imported lazily, only when a dataset input is present.
                from ._graph import _GraphEntityBuilder
                dataset_def = _GraphEntityBuilder._get_dataset_def_from_dataset(input_v.dset)
                data_set_parameter = DataPathParameter(
                    name=input_k,
                    default_value=dataset_def.value,
                    is_optional=False,
                    data_type_id='DataFrameDirectory')
                dataset_parameter_assignment.append(SubGraphDataPathParameterAssignment(
                    data_set_path_parameter=data_set_parameter,
                    # currently, we don't need to know which node is assigned
                    data_set_path_parameter_assignments=[]))

        return dataset_parameter_assignment

    def _find_input_external_port(self, input_source, pipeline):
        """Return the connection that feeds a sub pipeline input from outside.

        :param input_source: The ``dset`` of one of ``pipeline``'s inputs.
        :param pipeline: A pipeline whose ``_parent`` is set.
        :return: A one-element list of ``SubGraphConnectionInfo``, or ``[]`` when
            the source is not recognized.
        """
        # Directly from a data source node.
        if input_source in self.input_to_data_info_dict:
            node_id = _get_data_info_hash_id(self.input_to_data_info_dict[input_source])
            port_name = 'output'
            return [SubGraphConnectionInfo(port_name=port_name, node_id=node_id)]

        # From the parent pipeline's input.
        if isinstance(input_source, _InputBuilder):
            node_id = pipeline._parent._id
            port_name = input_source.name
            return [SubGraphConnectionInfo(port_name=port_name, node_id=node_id)]

        # From the output of another node of the parent pipeline.
        if isinstance(input_source, _OutputBuilder):
            # NOTE(review): from_node is None when no parent node owns this output, which
            # currently surfaces as an AttributeError in the module-node branch below.
            from_node = next((n for n in pipeline._parent.nodes if input_source in n.outputs.values()), None)
            if isinstance(from_node, PipelineComponent):
                node_id = from_node._id
                port_name = input_source._name
            else:
                node_id = self.module_node_to_graph_node_mapping[from_node._get_instance_id()]
                port_name = input_source.port_name
            return [SubGraphConnectionInfo(port_name=port_name, node_id=node_id)]

        return []

    def _find_input_internal_ports(self, input_val, pipeline):
        """Return connections from a pipeline input to the child-node inputs consuming it."""
        internals = []
        for node in pipeline.nodes:
            for node_input in node.inputs.values():
                if node_input.dset == input_val:
                    node_id = _get_graph_node_id(node, self.module_node_to_graph_node_mapping)
                    port_name = _get_graph_port_name(node, node_input)
                    internals.append(SubGraphConnectionInfo(port_name=port_name, node_id=node_id))

        return internals

    def _find_output_external_ports(self, output_val, pipeline):
        """Return connections consuming a pipeline output outside of the pipeline.

        Matches the parent pipeline's own outputs and the inputs of the parent
        pipeline's nodes. Returns ``[]`` when ``pipeline`` has no parent.
        """
        externals = []

        parent = pipeline._parent
        if parent is None:
            return externals

        # To the parent pipeline's output ports.
        for parent_output in parent.outputs.values():
            if output_val == parent_output:
                node_id = parent._id
                port_name = parent_output._name
                externals.append(SubGraphConnectionInfo(port_name=port_name, node_id=node_id))

        # To the inputs of the parent pipeline's nodes.
        for node in parent.nodes:
            for node_input in node.inputs.values():
                if output_val == node_input.dset:
                    node_id = _get_graph_node_id(node, self.module_node_to_graph_node_mapping)
                    port_name = _get_graph_port_name(node, node_input)
                    externals.append(SubGraphConnectionInfo(port_name=port_name, node_id=node_id))

        return externals

    def _find_output_internal_port(self, output_val, pipeline):
        """Return connections from child-node outputs that equal a pipeline output."""
        internals = []
        for node in pipeline.nodes:
            for output in node.outputs.values():
                if output_val == output:
                    node_id = _get_graph_node_id(node, self.module_node_to_graph_node_mapping)
                    port_name = output._name if isinstance(node, PipelineComponent) else output.port_name
                    internals.append(SubGraphConnectionInfo(port_name=port_name, node_id=node_id))

        return internals

    def _build_sub_graph_input_ports(self, pipeline):
        """Build port infos for the inputs of a sub pipeline (``[]`` for the root)."""
        if pipeline._parent is None:
            # don't need to build input ports for root pipeline
            return []

        input_ports = []
        for input_name, input_val in pipeline.inputs.items():
            externals = self._find_input_external_port(input_val.dset, pipeline)
            internals = self._find_input_internal_ports(input_val, pipeline)
            input_ports.append(SubGraphPortInfo(name=input_name, internal=internals, external=externals))

        return input_ports

    def _build_sub_graph_output_ports(self, pipeline):
        """Build port infos for the outputs of a sub pipeline (``[]`` for the root)."""
        if pipeline._parent is None:
            # don't need to build output ports for root pipeline
            return []

        output_ports = []
        for output_name, output_val in pipeline.outputs.items():
            externals = self._find_output_external_ports(output_val, pipeline)
            internals = self._find_output_internal_port(output_val, pipeline)
            output_ports.append(SubGraphPortInfo(name=output_name, internal=internals, external=externals))

        return output_ports


# region helper_functions
def _get_compute_setting(default_compute_target):
    if default_compute_target is None:
        return None
    elif isinstance(default_compute_target, str):
        return ComputeSetting(name=default_compute_target)
    elif isinstance(default_compute_target, tuple):
        if len(default_compute_target) == 2:
            return ComputeSetting(name=default_compute_target[0])
            # TODO: how to set proper compute_type
            # compute_type=default_compute_target[1])
    else:
        return ComputeSetting(name=default_compute_target.name)
        # compute_type=default_compute_target.type)


def _get_data_store_setting(default_data_store):
    """Convert a default datastore into a ``DatastoreSetting``.

    :param default_data_store: A datastore name string or an ``AbstractDatastore``.
    :return: A ``DatastoreSetting`` with the datastore name, or ``None`` for any
        other value (including ``None``).
    """
    from azureml.data.abstract_datastore import AbstractDatastore
    if isinstance(default_data_store, AbstractDatastore):
        store_name = default_data_store.name
    elif isinstance(default_data_store, str):
        store_name = default_data_store
    else:
        return None
    return DatastoreSetting(data_store_name=store_name)


def _correct_default_compute_target(sub_pipeline_definition, default_compute_target):
    """Overwrite ``default_compute_target`` on *sub_pipeline_definition* in place.

    The stored value is whatever ``_get_compute_setting`` builds from
    *default_compute_target*.
    """
    compute_setting = _get_compute_setting(default_compute_target)
    sub_pipeline_definition.default_compute_target = compute_setting


def _correct_default_data_store(sub_pipeline_definition, default_data_store):
    """Overwrite ``default_data_store`` on *sub_pipeline_definition* in place.

    The stored value is whatever ``_get_data_store_setting`` builds from
    *default_data_store*.
    """
    data_store_setting = _get_data_store_setting(default_data_store)
    sub_pipeline_definition.default_data_store = data_store_setting


def _normalize_from_module_name(from_module_name):
    """Return the bottom module file name.

    If from_module_name = 'some_module', return 'some_module'
    If from_module_name = 'some_module.sub_module', return 'sub_module'
    """
    if from_module_name is None:
        return None

    try:
        import re
        entries = re.split(r'[.]', from_module_name)
        return entries[-1]
    except Exception:
        return None


def _build_sub_pipeline_definition(name, description,
                                   default_compute_target, default_data_store,
                                   id, parent_definition_id=None,
                                   from_module_name=None, parameters=None, func_name=None):
    """Build a ``SubPipelineDefinition``.

    :param name: Definition name.
    :param description: Definition description.
    :param default_compute_target: Converted with ``_get_compute_setting``.
    :param default_data_store: Converted with ``_get_data_store_setting``.
    :param id: Definition id.
    :param parent_definition_id: Id of the parent definition, if any.
    :param from_module_name: Dotted module name; only its last component is stored.
    :param parameters: Iterable of objects with ``name`` and ``default``
        (presumably ``inspect.Parameter`` from the pipeline function signature);
        each becomes a ``Kwarg``. ``None`` means no parameters.
    :param func_name: Stored as ``pipeline_function_name``.
    :return: The built definition.
    :rtype: SubPipelineDefinition
    """
    # Qualified import: a bare ``Parameter`` would shadow the module-level REST ``Parameter`` model.
    import inspect

    def parameter_to_kv(parameter):
        # Parameters without a default are recorded with value None.
        value = None if parameter.default is inspect.Parameter.empty else parameter.default
        return Kwarg(key=parameter.name, value=value)

    compute_target = _get_compute_setting(default_compute_target)
    data_store = _get_data_store_setting(default_data_store)
    parameter_list = [] if parameters is None else [parameter_to_kv(p) for p in parameters]

    return SubPipelineDefinition(name=name, description=description,
                                 default_compute_target=compute_target, default_data_store=data_store,
                                 pipeline_function_name=func_name,
                                 id=id, parent_definition_id=parent_definition_id,
                                 from_module_name=_normalize_from_module_name(from_module_name),
                                 parameter_list=parameter_list)


def _get_graph_node_id(node, module_node_to_graph_node_mapping):
    """Return the graph node id of *node*.

    A ``PipelineComponent`` uses its own ``_id``; any other node is looked up in
    *module_node_to_graph_node_mapping* by ``node._get_instance_id()``.
    """
    if not isinstance(node, PipelineComponent):
        return module_node_to_graph_node_mapping[node._get_instance_id()]
    return node._id


def _get_graph_port_name(node, input):
    """Return the graph port name of *input* on *node*.

    A ``PipelineComponent`` uses the input's own ``name``; any other node maps
    that name through ``node._pythonic_name_to_input_map``.
    """
    if not isinstance(node, PipelineComponent):
        return node._pythonic_name_to_input_map[input.name]
    return input.name
# endregion
