# Copyright 2022 Tiger Miao
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""format output"""

import os
import re
from textwrap import fill
import traceback
from prettytable import PrettyTable
from troubleshooter.common.util import print_line
from troubleshooter.common.information_build import get_errmsg_dict

# Wrap/pad width (in characters) used for the text placed in report table cells.
TABLE_WIDTH = 50
# Length passed to print_line when drawing the "-" delimiter lines around section titles.
DELIMITER_LEN = 100

# Maps internal report item keys to the Chinese labels displayed in the report table.
_item_to_cn = {
    'item': '项目',
    'desc': '描述',
    'ms_version': '版本信息:',
    'ms_mode': '执行模式:',
    'ms_device': '配置设备:', # GPU, CPU, Ascend
    'ms_status': '执行阶段:',
    'code_line': '代码行:',
    'cause': '可能原因:',
    'err_code': '示例错误代码：',
    'fixed_code': '示例正确代码：',
    'proposal': '处理建议:',
    'case': '相关案例:',
    'sink_mode': '下沉模式:'   # graph sink, dataset sink, non-sink
}

def _add_row(x, item, message, width=TABLE_WIDTH, break_long_words=False, break_on_hyphens=False):
    """Append one labelled, wrapped row to the report table.

    Args:
        x: table object exposing ``add_row`` (a PrettyTable in this module).
        item: key into ``_item_to_cn`` selecting the label of the first column.
        message: text for the second column; ``None`` means "no row".
        width: wrap width forwarded to ``textwrap.fill``.
        break_long_words: forwarded to ``textwrap.fill``.
        break_on_hyphens: forwarded to ``textwrap.fill``.
    """
    if message is not None:
        text = message
        if os.linesep in text:
            # Multi-line text is padded line by line first; presumably so that
            # fill() still breaks at the original line boundaries.
            text = _format_str_length(text)
        wrapped = fill(
            text,
            width=width,
            break_long_words=break_long_words,
            break_on_hyphens=break_on_hyphens,
        )
        x.add_row([_item_to_cn.get(item), wrapped])

def print_result(expert_experience):
    """Print the failure analysis report table for one matched experience entry.

    Args:
        expert_experience: mapping (accessed via ``.get``) holding the report
            fields: ``mindspore_version``, ``mindspore_mode``, ``Device Type``,
            ``ms_status``, ``code_line``, ``Sink Mode``, ``Fault Cause``,
            ``Error Case``, ``Modification Suggestion``, ``Fixed Case`` and
            ``Fault Case``.
    """
    labels = _item_to_cn
    table = PrettyTable()
    table.title = 'MindSpore FAR(Failure Analysis Report)'
    table.field_names = [labels.get("item"), labels.get("desc")]
    table.align[labels.get("desc")] = 'l'

    mindspore_version = expert_experience.get("mindspore_version")
    mindspore_mode = expert_experience.get("mindspore_mode")

    # Version and execution mode are always listed.
    for label_key, value in (("ms_version", mindspore_version), ("ms_mode", mindspore_mode)):
        table.add_row([labels.get(label_key), fill(value, width=TABLE_WIDTH)])

    # Environment rows that are listed only when a non-empty value is present.
    optional_rows = (
        ("ms_device", "Device Type"),
        ("ms_status", "ms_status"),
        ("code_line", "code_line"),
        ("sink_mode", "Sink Mode"),
    )
    for label_key, field in optional_rows:
        value = expert_experience.get(field)
        if value:
            table.add_row([labels.get(label_key), fill(value, width=TABLE_WIDTH)])

    # Possible cause.
    _add_row(table, "cause", expert_experience.get('Fault Cause'))
    # Example of the failing code.
    _add_row(table, "err_code", _format_code_str(expert_experience.get("Error Case")))
    # Suggested modification.
    _add_row(table, "proposal", expert_experience.get("Modification Suggestion"))
    # Example of the corrected code.
    _add_row(table, "fixed_code", _format_code_str(expert_experience.get("Fixed Case")))
    # Related cases, with links adjusted to the running MindSpore version.
    related_cases = _format_case_str(expert_experience.get("Fault Case"), mindspore_version)
    _add_row(table, "case", related_cases)

    print(table.get_string())

def _print_msg(title, msg=None, print_msg=True):
    """Print a delimited section header and, optionally, the section text.

    Nothing is printed when ``msg`` is empty or ``None``.

    Args:
        title: section title shown between two delimiter lines.
        msg: section text; also acts as the switch that enables the section.
        print_msg: when false only the header is printed, not ``msg`` itself.
    """
    if not msg:
        return

    print_line("-", DELIMITER_LEN)
    print("-  " + title)
    print_line("-", DELIMITER_LEN)

    if not print_msg:
        return
    # Drop trailing line-separator characters, then emit exactly one blank line after the text.
    body = msg.rstrip(os.linesep)
    print(body + os.linesep)

def _print_stack(exc_traceback_obj):
    """Print the Python traceback section for ``exc_traceback_obj``, if one is given."""
    if not exc_traceback_obj:
        return
    # "NULL" is only a non-empty placeholder so that the header gets printed;
    # print_msg=False keeps the placeholder itself out of the output.
    _print_msg("Python Traceback (most recent call last):", msg="NULL", print_msg=False)
    traceback.print_tb(exc_traceback_obj)
    print("")

def _format_str_length(string):
    """Pad every line of ``string`` with spaces up to ``TABLE_WIDTH`` characters.

    The text is split on ``os.linesep``; each piece is left-justified to
    ``TABLE_WIDTH`` and terminated with ``os.linesep`` (so the result always
    ends with a line separator).
    """
    padded_lines = (
        piece.ljust(TABLE_WIDTH) + os.linesep
        for piece in string.split(os.linesep)
    )
    return "".join(padded_lines)

def _format_code_str(content, width=50):
    if content:
        lines = content.split("\n")
        result = '+'.ljust(width-2, '-')+ '+' + os.linesep
        line = lines[1]
        if line == '':
            return content
        j = 0
        while line[j] == ' ':
            j = j + 1
        for line in lines:
            if line == lines[0]:
                continue
            line = line[j:]
            pre_line = "> " + line + os.linesep
            result += pre_line
        result += '+'.ljust(width-2, '-')+ '+'
        content = result
    return  content

# Rewrite the MindSpore version embedded in case links so that it matches the current MindSpore version.

# replace mindspore version in link
def _replace(link, keys, target):
    """
    :param link: str, the case web page's link
    :param keys: list[str], regular expressions for different web page, like note, api and faq.
    :param target: list[str], the link keywords of web pages of note, api and faq, corresponding to current
                   mindspore version
    :return: link: replaced web page's link, corresponding to current mindspore version
    """
    for i, key in enumerate(keys):
        match = re.search(key, link)
        if match:
            link = link[:match.start()] + target[i] +link[match.end():]
            break
    return link

def _replace_link_version(link, link_version, mindspore_version):
    if mindspore_version < link_version:
        return link
    if mindspore_version < "r1.7" or link_version >= "r1.7":
        match = re.search(link_version, link)
        link = link[:match.start()] + mindspore_version + link[match.end():]
        return link
    if link_version < "r1.7" and mindspore_version >= "r1.7":
        keys = ["note\/zh-CN\/{}".format(link_version),
                "api\/zh-CN\/{}".format(link_version),
                "faq\/zh-CN\/{}".format(link_version)]
        target = ["zh-CN/{}/note".format(mindspore_version),
                  "zh-CN/{}".format(mindspore_version),
                  "zh-CN/{}/faq".format(mindspore_version)]
        link = _replace(link, keys, target)
    return link

def _format_case_str(content, mindspore_version):
    if content:
        # match, no replace
        lines = content.split(os.linesep)
        result = ""
        for line in lines:
            line = line.lstrip()
            match = re.search(mindspore_version, line)
            if not match:
                key = "(r1\.[6-9])|(r2\.[0-9])"
                match = re.search(key, line)
                # no mindspore link, return
                if match:
                    link_version = line[match.start():match.end()]
                    line = _replace_link_version(line, link_version, mindspore_version)
                result += line + os.linesep
            else: # link version same with mindspore version, no replace
                return content
        content = result
    return content

def print_format_exception(exc_type, exc_value, exc_traceback_obj):
    """Print an exception, split into labelled sections for MindSpore older than 1.8.

    For MindSpore 1.8 and newer the current exception is printed with
    ``traceback.print_exc()``. For older versions the message is split by
    ``get_errmsg_dict`` and printed section by section.

    Args:
        exc_type: exception type, forwarded to ``get_errmsg_dict``.
        exc_value: exception instance, forwarded to ``get_errmsg_dict``.
        exc_traceback_obj: traceback object printed as the Python stack section.
    """
    import mindspore

    # Compare (major, minor) numerically: a 3-character prefix would turn
    # "1.10.x" into "1.1" and wrongly sort it before "1.8".
    version_match = re.match(r"(\d+)\.(\d+)", mindspore.__version__)
    if version_match:
        major_minor = (int(version_match.group(1)), int(version_match.group(2)))
        is_new_format = major_minor >= (1, 8)
    else:
        # Unparseable version string: keep the plain prefix comparison.
        is_new_format = mindspore.__version__[:3] >= '1.8'

    if is_new_format:
        traceback.print_exc()
        return

    msg_dict = get_errmsg_dict(exc_type, exc_value)
    _print_stack(exc_traceback_obj)
    _print_msg("Error Message:", msg_dict.get("err_msg"))

    # Prefer the construct stack message; fall back to the "in file" variant.
    construct_msg = msg_dict.get("construct_stack_msg") or msg_dict.get("construct_stack_in_file_msg")
    _print_msg("The Traceback of Net Construct Code:", construct_msg)
    _print_msg("C++ Function:", msg_dict.get("cpp_fun_msg"))
    _print_msg("Inner Message:", msg_dict.get("abstract_inner_msg"))
