#!/usr/bin/python3

import argparse
import os
import pathlib
import platform
import sys
from typing import List, Union

import mgeconvert

# Raise the interpreter's recursion limit far above the default as soon as the
# script is loaded, before any converter is invoked.
# NOTE(review): presumably needed because converting deep model graphs recurses
# heavily -- confirm against the mgeconvert implementation.
sys.setrecursionlimit(1000000)

def get_targets(module):
    """Return the attribute names of *module* that name a converter.

    An attribute qualifies when its name starts with ``"mge_to"`` or
    ``"tracedmodule_to"``. Names come back in the order ``dir(module)``
    yields them.
    """
    converter_prefixes = ("mge_to", "tracedmodule_to")
    return [name for name in dir(module) if name.startswith(converter_prefixes)]

sub_targets = get_targets(mgeconvert)

# (marker, words) pairs checked in order: the first marker found inside a
# converter name decides which sub-command words that name contributes to the
# completion list. Names matching no marker contribute nothing.
_COMPLETION_GROUPS = (
    ("tracedmodule_to_caffe", "tracedmodule_to_caffe mge_to_caffe "),
    ("tracedmodule_to_onnx", "tracedmodule_to_onnx mge_to_onnx "),
    ("tracedmodule_to_tflite", "tracedmodule_to_tflite mge_to_tflite "),
)

# Space-separated sub-command names offered after ``convert``; the string
# starts with a single space and every contributed group ends with one.
complete_targets = " " + "".join(
    next((words for marker, words in _COMPLETION_GROUPS if marker in name), "")
    for name in sub_targets
)
# Bash completion script for the ``convert`` command. The single ``%s`` is
# filled with ``complete_targets`` (the sub-command names offered right after
# ``convert``). The generated ``_mgeconvert`` function looks at the word before
# the cursor:
#   * ``convert``            -> offer the sub-command names;
#   * a sub-command name     -> offer that sub-command's options;
#   * a word starting with - -> offer the options of the sub-command found in
#                               ``COMP_WORDS[1]``.
# The option list of a sub-command therefore appears twice and the two copies
# must be kept in sync: the nested tracedmodule_to_caffe list previously
# offered ``--fuse_relu`` (no parser defines it; the option is ``--fuse_bn``),
# and the top-level tracedmodule_to_tflite list lacked the last three options
# of its nested copy.
complete_str = """_mgeconvert(){
    local script word
    COMPREPLY=()
    script="${COMP_WORDS[COMP_CWORD-1]}"
    word="${COMP_WORDS[COMP_CWORD]}"
    case ${script} in
        convert)
            words="%s"
            COMPREPLY=( $(compgen -W "$words" -- $word) )
            return
            ;;
        tracedmodule_to_caffe)
            words="-i --input -c --prototxt -b --caffemodel --end_point --use_empty_blobs --input_data_type --input_scales --input_zero_points --require_quantize --param_fake_quant --split_conv_relu  --fuse_bn --convert_backend  --outspec"
            COMPREPLY=( $(compgen -W "$words" -- $word) )
            return
            ;;
        mge_to_caffe)
            words="-i --input -c --prototxt -b --caffemodel --end_point --outspec"
            COMPREPLY=( $(compgen -W "$words" -- $word) )
            return
            ;;
        tracedmodule_to_onnx)
            words="-i --input -o --output --opset --graph_name --end_point --require_quantize --param_fake_quant --quantize_file_path --input_data_type --input_scales --input_zero_points  --outspec"
            COMPREPLY=( $(compgen -W "$words" -- $word) )
            return
            ;;
        mge_to_onnx)
            words="-i --input -o --output --opset --graph_name --end_point  --outspec"
            COMPREPLY=( $(compgen -W "$words" -- $word) )
            return
            ;;
        tracedmodule_to_tflite)
            words="-i --input -o --output --input_data_type --input_scales --input_zero_points --require_quantize --param_fake_quant --quantize_file_path --graph_name --mtk --end_point  --outspec --remove_relu --prefer_same_pad_mode --use_int64_bias"
            COMPREPLY=( $(compgen -W "$words" -- $word) )
            return
            ;;
        mge_to_tflite)
            words="-i --input -o --output --graph_name --mtk --end_point  --outspec"
            COMPREPLY=( $(compgen -W "$words" -- $word) )
            return
            ;;
        -*)
            subscript=${COMP_WORDS[1]}
            case $subscript in
                tracedmodule_to_caffe)
                    words="-i --input -c --prototxt -b --caffemodel --end_point --use_empty_blobs --input_data_type --input_scales --input_zero_points --require_quantize --param_fake_quant --split_conv_relu --fuse_bn --convert_backend  --outspec"
                    COMPREPLY=( $(compgen -W "$words" -- $word) )
                    return
                    ;;
                mge_to_caffe)
                    words="-i --input -c --prototxt -b --caffemodel --end_point --outspec"
                    COMPREPLY=( $(compgen -W "$words" -- $word) )
                    return
                    ;;
                tracedmodule_to_onnx)
                    words="-i --input -o --output --opset --graph_name --end_point --require_quantize --param_fake_quant --quantize_file_path --input_data_type --input_scales --input_zero_points --outspec"
                    COMPREPLY=( $(compgen -W "$words" -- $word) )
                    return
                    ;;
                mge_to_onnx)
                    words="-i --input -o --output --opset --graph_name --end_point  --outspec"
                    COMPREPLY=( $(compgen -W "$words" -- $word) )
                    return
                    ;;
                tracedmodule_to_tflite)
                    words="-i --input -o --output --input_data_type --input_scales --input_zero_points --require_quantize --param_fake_quant --quantize_file_path --graph_name --mtk --end_point  --outspec --remove_relu --prefer_same_pad_mode --use_int64_bias"
                    COMPREPLY=( $(compgen -W "$words" -- $word) )
                    return
                    ;;
                mge_to_tflite)
                    words="-i --input -o --output --graph_name --mtk --end_point  --outspec"
                    COMPREPLY=( $(compgen -W "$words" -- $word) )
                    return
                    ;;
            esac
            return
            ;;
    esac
}
complete -o bashdefault -F _mgeconvert convert""" % complete_targets


def init(subparsers):
    """Register the converter sub-command selected on the command line.

    The sub-command name is taken from ``sys.argv[1]``. Depending on whether
    that name contains ``"caffe"``, ``"onnx"`` or ``"tflite"``, a sub-parser
    with exactly that name is added to *subparsers*, and its ``func`` default
    is set to a closure that forwards the parsed arguments to the matching
    ``mgeconvert`` converter. When no argument was given, or the first
    argument matches none of the three keywords (e.g. ``--init`` or ``-h``),
    nothing is registered.

    Args:
        subparsers: The object returned by
            ``argparse.ArgumentParser.add_subparsers()``.
    """
    # A bare ``convert`` has no argv[1]; use an empty name so that no
    # sub-parser is registered and the caller can report the missing
    # sub-command instead of this function raising IndexError.
    target = sys.argv[1] if len(sys.argv) > 1 else ""

    if "caffe" in target:

        def to_caffe(args):
            """Run the caffe converter named by ``target`` with parsed *args*."""
            # ``--end_point`` is a ";"-separated list of output names.
            outspec = None
            if args.end_point is not None:
                outspec = args.end_point.split(";")
            converter_map = {
                "tracedmodule_to_caffe": mgeconvert.tracedmodule_to_caffe,
                "mge_to_caffe": mgeconvert.mge_to_caffe,
            }
            # Backends may be given either by name or by their numeric id.
            convert_backend_map = {
                "caffe": 1,
                "1": 1,
                "snpe": 2,
                "2": 2,
                "trt": 3,
                "3": 3,
                "nnie": 4,
                "4": 4,
            }
            assert args.convert_backend in convert_backend_map, "unsupported backend"
            convert_backend = convert_backend_map[args.convert_backend]
            if "tracedmodule" in target:
                # The quantization-related options only exist on the
                # tracedmodule sub-parser (see caffe_parser below).
                converter_map[target](
                    args.input,
                    prototxt=args.prototxt,
                    caffemodel=args.caffemodel,
                    use_empty_blobs=args.use_empty_blobs,
                    input_data_type=args.input_data_type,
                    input_scales=args.input_scales,
                    input_zero_points=args.input_zero_points,
                    require_quantize=args.require_quantize,
                    param_fake_quant=args.param_fake_quant,
                    split_conv_relu=args.split_conv_relu,
                    fuse_bn=args.fuse_bn,
                    quantize_file_path=args.quantize_file_path,
                    convert_backend=convert_backend,
                )
            else:
                converter_map[target](
                    args.input,
                    prototxt=args.prototxt,
                    caffemodel=args.caffemodel,
                    outspec=outspec,
                    convert_backend=convert_backend,
                )

        def caffe_parser(subparsers):
            """Add the ``target`` sub-parser with the caffe options."""
            p = subparsers.add_parser(target)
            p.set_defaults(func=to_caffe)
            p.add_argument(
                "-i",
                "--input",
                required=True,
                type=str,
                help="Input megengine dump model file",
            )
            p.add_argument(
                "-c",
                "--prototxt",
                required=True,
                type=str,
                help="Output caffe .prototxt file",
            )
            p.add_argument(
                "-b",
                "--caffemodel",
                required=True,
                type=str,
                help="Output caffe .caffemodel file",
            )

            p.add_argument(
                "--end_point",
                default=None,
                type=str,
                help="end_point is used to specify which part of the mge model should be converted",
            )

            p.add_argument(
                "--use_empty_blobs",
                default=False,
                action="store_true",
                help="use empty blobs when using caffe",
            )

            if "tracedmodule" in target:
                p.add_argument(
                    "--input_data_type",
                    default=None,
                    type=str,
                    help="input dtype of caffe",
                )

                p.add_argument(
                    "--input_scales", default=None, nargs="+", type=str,
                )

                p.add_argument(
                    "--input_zero_points", default=None, nargs="+", type=str,
                )

                p.add_argument(
                    "--require_quantize",
                    default=False,
                    action="store_true",
                    help="whether to do quantize for QAT model, if not then save qparams to the specified json file",
                )

                p.add_argument(
                    "--param_fake_quant",
                    default=False,
                    action="store_true",
                    help="whether to do fake quantize for params",
                )

                p.add_argument(
                    "--split_conv_relu",
                    default=False,
                    action="store_true",
                    help="do not use relu for snpe",
                )

                p.add_argument(
                    "--fuse_bn",
                    default=False,
                    action="store_true",
                    help="fuse bn",
                )

                p.add_argument(
                    "--quantize_file_path",
                    default="quant_params.json",
                    type=str,
                    help="the output path for quantize file",
                )

            p.add_argument(
                "--convert_backend",
                default='1',
                type=str,
                help="backend  '1' or 'caffe' for caffe, '2' or 'snpe' for snpe, '3' or 'trt' for trt, '4' or 'nnie' for nnie",
            )

        caffe_parser(subparsers)
    if "onnx" in target:

        def convert_func(args):
            """Run the onnx-related converter named by ``target``."""
            # ``--end_point`` is a ";"-separated list of output names.
            outspec = None
            if args.end_point is not None:
                outspec = args.end_point.split(";")
            converter_map = {
                "tracedmodule_to_onnx": mgeconvert.tracedmodule_to_onnx,
                "mge_to_onnx": mgeconvert.mge_to_onnx,
                "onnx_to_tracedmodule": mgeconvert.onnx_to_tracedmodule,
                "onnx_to_mge": mgeconvert.onnx_to_mge,
            }

            # The branch order matters: the substring tests distinguish the
            # four directions (x_to_onnx vs. onnx_to_x).
            if "tracedmodule_to" in target:
                converter_map[target](
                    args.input,
                    args.output,
                    graph_name=args.graph_name,
                    opset=args.opset,
                )
            elif "mge_to" in target:
                converter_map[target](
                    args.input,
                    args.output,
                    graph_name=args.graph_name,
                    opset=args.opset,
                    outspec=outspec,
                )
            elif "onnx_to_mge" in target:
                converter_map[target](
                    args.input,
                    args.output,
                    optimize_for_inference=args.optimize_for_inference,
                    frozen_input_shape=args.frozen_input_shape
                )

            else:
                converter_map[target](
                    args.input,
                    args.output,
                )

        def onnx_parser(subparsers):
            """Add the ``target`` sub-parser with the onnx options."""
            p = subparsers.add_parser(target)
            p.set_defaults(func=convert_func)
            p.add_argument(
                "-i",
                "--input",
                required=True,
                type=str,
                help="Input megengine dump model file or onnx model",
            )
            p.add_argument(
                "-o",
                "--output",
                required=True,
                type=str,
                help="Output onnx .onnx file or megengine model file(mge c++ model or traced module)",
            )
            p.add_argument(
                "--opset", default=8, type=int, help="Onnx opset version"
            )
            p.add_argument(
                "--graph_name", default="graph", type=str, help="Onnx graph name"
            )
            p.add_argument(
                "--end_point",
                default=None,
                type=str,
                help="end_point is used to specify which part of the mge model should be converted",
            )
            p.add_argument(
                "--frozen_input_shape",
                action="store_true",
                help="when enable, onnxsim use dynamic_input_shape=False to optimize model for current shape, it will frozen shape",
            )
            if "mge" in target:
                # argparse stores this as ``args.optimize_for_inference``.
                p.add_argument(
                    "--optimize-for-inference", action="store_true", help="output model optimize for inference"
                )

            if "tracedmodule" in target:

                p.add_argument(
                    "--input_data_type",
                    default=None,
                    type=str,
                    help="input dtype of onnx",
                )

                p.add_argument(
                    "--input_scales", default=None, nargs="+", type=str,
                )

                p.add_argument(
                    "--input_zero_points", default=None, nargs="+", type=str,
                )

                p.add_argument(
                    "--require_quantize",
                    default=False,
                    action="store_true",
                    help="whether to do quantize for QAT model, if not then save qparams to the specified json file",
                )

                p.add_argument(
                    "--param_fake_quant",
                    default=False,
                    action="store_true",
                    help="whether to do fake quantize for params",
                )

                p.add_argument(
                    "--quantize_file_path",
                    default="quant_params.json",
                    type=str,
                    help="the output path for quantize file",
                )

        onnx_parser(subparsers)
    if "tflite" in target:

        def to_tflite(args):
            """Run the tflite converter matching ``target`` with parsed *args*."""
            if "tracedmodule" in target:
                mgeconvert.tracedmodule_to_tflite(
                    args.input,
                    output=args.output,
                    input_data_type=args.input_data_type,
                    input_scales=args.input_scales,
                    input_zero_points=args.input_zero_points,
                    require_quantize=args.require_quantize,
                    param_fake_quant=args.param_fake_quant,
                    quantize_file_path=args.quantize_file_path,
                    graph_name=args.graph_name,
                    mtk=args.mtk,
                    outspec=args.outspec,
                    remove_relu=args.remove_relu,
                    prefer_same_pad_mode=args.prefer_same_pad_mode,
                    use_int64_bias=args.use_int64_bias,
                )
            else:
                mgeconvert.mge_to_tflite(
                    args.input,
                    output=args.output,
                    graph_name=args.graph_name,
                    mtk=args.mtk,
                    outspec=args.outspec,
                    prefer_same_pad_mode=args.prefer_same_pad_mode,
                    use_int64_bias=args.use_int64_bias,
                )

        def tflite_parser(subparsers):
            """Add the ``target`` sub-parser with the tflite options."""

            def split_names(value):
                """Split a ";"-separated option value into a list of names."""
                return value.split(";")

            p = subparsers.add_parser(target)
            p.set_defaults(func=to_tflite)
            p.add_argument(
                "-i",
                "--input",
                required=True,
                type=str,
                help="megengine dumped model file",
            )
            p.add_argument(
                "-o",
                "--output",
                required=True,
                type=str,
                help="converted TFLite model file",
            )
            if "tracedmodule" in target:
                p.add_argument(
                    "--input_data_type",
                    default=None,
                    type=str,
                    help="the dtype of input data used for quantize model input",
                )
                p.add_argument(
                    "--input_scales",
                    default=None,
                    type=str,
                    help="the scale of input data used for quantize model input",
                )
                p.add_argument(
                    "--input_zero_points",
                    default=None,
                    type=str,
                    help="the zero point of input data used for quantize model input",
                )
                p.add_argument(
                    "--require_quantize",
                    action="store_true",
                    help="whether to do quantize for QAT model, if not then save qparams to the specified json file",
                )
                p.add_argument(
                    "--param_fake_quant",
                    action="store_true",
                    help="whether to do fake quantize for params",
                )
                p.add_argument(
                    "--quantize_file_path",
                    default="quant_params.json",
                    type=str,
                    help="",
                )
                p.add_argument(
                    "--remove_relu",
                    action="store_true",
                    help="whether to remove relu operation",
                )
            # Accept the underscore spelling as well: it is what the bash
            # completion script offers and what the onnx sub-command uses.
            # The destination is ``graph_name`` either way.
            p.add_argument(
                "--graph-name",
                "--graph_name",
                default="graph0",
                type=str,
                help="default subgraph name in TFLite model",
            )
            p.add_argument(
                "--mtk",
                action="store_true",
                help="If target flatform is MTK(P70, P80)",
            )
            # ``type=list`` would break the value into single characters;
            # split on ";" instead, the separator ``--end_point`` uses for the
            # caffe and onnx sub-commands.
            p.add_argument(
                "--outspec",
                default=None,
                type=split_names,
                help="the names of end points of the model",
            )
            p.add_argument(
                "--prefer_same_pad_mode",
                action="store_true",
                help="whether prefer to use SAME pad mode for conv op",
            )

            p.add_argument(
                "--use_int64_bias",
                action="store_true",
                help="whether use int64 as dtype of bias",
            )

        tflite_parser(subparsers)


def add_completion(path):
    """Write the bash completion script to *path*.

    Missing parent directories are created first; an existing file at
    *path* is overwritten.
    """
    script_file = pathlib.Path(path)
    script_file.parent.mkdir(parents=True, exist_ok=True)
    with script_file.open("w+") as handle:
        handle.write(complete_str)


def main():
    """Entry point of the ``convert`` command line tool.

    Builds the argument parser, lets ``init`` register the sub-command named
    on the command line, and then does one of three things:

    * with ``--init``: installs the bash completion script under the user's
      home directory, prints shell-specific instructions, and exits with
      status 0 (on Windows only a warning is printed);
    * with a registered sub-command: calls its ``func`` handler;
    * otherwise: prints a hint on how to get help.
    """
    targets = get_targets(mgeconvert)
    msg = targets[0] if len(targets) == 1 else "{" + ", ".join(targets) + "}"
    parser = argparse.ArgumentParser(
        description='use "convert %s -h" for more details' % msg
    )
    parser.add_argument(
        "--init", default=False, action="store_true", help="init the bash completion"
    )
    subparsers = parser.add_subparsers()
    init(subparsers)
    args = parser.parse_args()
    if args.init:
        sysstr = platform.system()
        if sysstr == "Windows":
            print("WARNING: windows doesn't support hinting")
        else:
            # expanduser("~") honours $HOME when it is set and otherwise falls
            # back to the password database, so an unset HOME no longer yields
            # a path under a literal "None" directory.
            path = "{}/.local/share/bash-completion/completions/mgeconvert".format(
                os.path.expanduser("~")
            )
            add_completion(path)
            shell = os.environ.get("SHELL", "/usr/bash")
            if shell.find("zsh") != -1:
                print(
                    " if you don't have zsh completion init, please excute command: 'autoload -Uz compinit &&  compinit'"
                )
                print(
                    "Guess you are using zsh, please add `source %s` to your ~/.zshrc"
                    % path
                )
            elif shell.find("bash") != -1:
                print(
                    "Guess you are using bash, please relogin or do `source %s`" % path
                )
            else:
                print(
                    "Current {} doesn't support hinting shell completion".format(shell)
                )
        sys.exit(0)

    # ``func`` is only present when a sub-parser was registered and selected.
    if hasattr(args, "func"):
        args.func(args)
    else:
        print("[error] please point out which framework you want to convert")
        print('use "convert -h" for help')

# Run the command line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
