import onnx
from onnx import numpy_helper

from onnx2pytorch.utils import to_pytorch_params

# Reverse lookup for onnx TensorProto data types: enum integer -> type name
# (e.g. 1 -> "FLOAT"). DataType.items() yields (name, value) pairs, so each
# pair is flipped here.
TENSOR_PROTO_MAPPING = {
    value: name for name, value in onnx.TensorProto.DataType.items()
}

# Integer codes of (a subset of) the onnx AttributeProto.AttributeType enum,
# keyed by enum name. extract_attr_values compares ``attr.type`` against these.
AttributeType = {
    "UNDEFINED": 0,
    "FLOAT": 1,
    "INT": 2,
    "STRING": 3,
    "TENSOR": 4,
    "GRAPH": 5,
    "SPARSE_TENSOR": 11,
    "FLOATS": 6,
    "INTS": 7,
    "STRINGS": 8,
    "TENSORS": 9,
    "GRAPHS": 10,
    "SPARSE_TENSORS": 12,
}


def extract_attr_values(attr):
    """Extract the Python value stored in an onnx attribute.

    The attribute's ``type`` code (see ``AttributeType``) selects which
    payload field is read.

    Args:
        attr: onnx ``AttributeProto``-like object.

    Returns:
        - ``int`` for INT and ``float`` for FLOAT.
        - ``tuple`` for INTS and FLOATS.
        - A numpy array (via ``numpy_helper.to_array``) for TENSOR.
        - A tuple of such arrays for TENSORS.
        - ``str`` for STRING, decoded from bytes with the default UTF-8 codec.
        - A tuple of ``str`` for STRINGS.

    Raises:
        NotImplementedError: for any other attribute type (e.g. GRAPH,
            GRAPHS, SPARSE_TENSOR, UNDEFINED).
    """
    attr_type = attr.type
    if attr_type == AttributeType["INT"]:
        value = attr.i
    elif attr_type == AttributeType["FLOAT"]:
        value = attr.f
    elif attr_type == AttributeType["INTS"]:
        value = tuple(attr.ints)
    elif attr_type == AttributeType["FLOATS"]:
        value = tuple(attr.floats)
    elif attr_type == AttributeType["TENSOR"]:
        value = numpy_helper.to_array(attr.t)
    elif attr_type == AttributeType["TENSORS"]:
        value = tuple(numpy_helper.to_array(tensor) for tensor in attr.tensors)
    elif attr_type == AttributeType["STRING"]:
        value = attr.s.decode()
    elif attr_type == AttributeType["STRINGS"]:
        # Repeated string fields hold bytes; decode each like the STRING case.
        value = tuple(raw.decode() for raw in attr.strings)
    else:
        raise NotImplementedError(
            "Extraction of attribute type {} not implemented.".format(attr.type)
        )
    return value


# ONNX attribute name -> PyTorch keyword for attributes whose extracted value
# is passed through unchanged.
_PASSTHROUGH_ATTRIBUTES = {
    "dilations": "dilation",
    "group": "groups",
    "kernel_shape": "kernel_size",
    "strides": "stride",
    "epsilon": "eps",
    "momentum": "momentum",
    "value": "constant",
    "perm": "dims",
    "split": "split_size_or_sections",
    "spatial": "spatial",  # Batch norm parameter
    "mode": "mode",
    "alpha": "weight_multiplier",
    "beta": "bias_multiplier",
}

# ONNX attribute name -> PyTorch keyword for integer flags converted to bool.
_BOOLEAN_ATTRIBUTES = {
    "keepdims": "keepdim",
    "ceil_mode": "ceil_mode",
    "transA": "transpose_activation",
}

# auto_pad values that need no extra kwargs here.
# - NOTSET: padding comes from the explicit "pads" attribute, if any.
# - VALID: no padding. Per the ONNX spec, "pads" may not be combined with
#   auto_pad, so this produces the same kwargs as NOTSET without pads.
_SUPPORTED_AUTO_PAD = ("NOTSET", "VALID")


def extract_attributes(node):
    """Extract onnx attributes and map onnx feature naming to pytorch.

    Args:
        node: onnx ``NodeProto``-like object. Its ``attribute`` list is
            converted, and its ``op_type`` disambiguates ``axis`` for
            ``Flatten``.

    Returns:
        dict of PyTorch-style keyword arguments.

    Raises:
        NotImplementedError: for an attribute name that has no mapping, for
            an ``auto_pad`` value other than NOTSET/VALID, or (from
            ``extract_attr_values``) for an unsupported attribute type.
    """
    kwargs = {}
    for attr in node.attribute:
        name = attr.name
        if name in _PASSTHROUGH_ATTRIBUTES:
            kwargs[_PASSTHROUGH_ATTRIBUTES[name]] = extract_attr_values(attr)
        elif name in _BOOLEAN_ATTRIBUTES:
            kwargs[_BOOLEAN_ATTRIBUTES[name]] = bool(extract_attr_values(attr))
        elif name == "pads":
            kwargs["padding"] = to_pytorch_params(extract_attr_values(attr))
        elif name == "axis" and node.op_type == "Flatten":
            # Flatten's axis is the first dimension that gets flattened.
            kwargs["start_dim"] = extract_attr_values(attr)
        elif name in ("axis", "axes"):
            value = extract_attr_values(attr)
            # Unwrap a single-element axes list into a scalar dim.
            if isinstance(value, (tuple, list)) and len(value) == 1:
                value = value[0]
            kwargs["dim"] = value
        elif name == "to":
            # Cast target: TensorProto enum int -> lowercase type name.
            kwargs["dtype"] = TENSOR_PROTO_MAPPING[extract_attr_values(attr)].lower()
        elif name == "transB":
            # Deliberately negated. NOTE(review): presumably the consumer's
            # weight layout already matches a transposed B; confirm with the
            # Gemm converter.
            kwargs["transpose_weight"] = not extract_attr_values(attr)
        elif name == "auto_pad":
            value = extract_attr_values(attr)
            if value not in _SUPPORTED_AUTO_PAD:
                raise NotImplementedError(
                    "auto_pad={} functionality not implemented.".format(value)
                )
        else:
            raise NotImplementedError(
                "Extraction of attribute {} not implemented.".format(attr.name)
            )
    return kwargs
