#!/usr/bin/env python3

# Copyright 2019 ETH Zurich and University of Bologna
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Author: Robert Balas (balasr@iis.ee.ethz.ch)

import argparse
import sys
import re
import os
import subprocess
from functools import reduce
from collections import OrderedDict

# Command-line interface: two positional inputs (the trace log and the ELF
# that produced it) plus switches controlling output and formatting.
parser = argparse.ArgumentParser(prog='pulptrace',
                                 description="""Combine objdump information
                                 with an instruction trace log from a pulp
                                 core""")

parser.version = '0.1'
# expose the version set above; without an action it was never reachable
parser.add_argument('--version', action='version',
                    version='%(prog)s ' + parser.version)
parser.add_argument('trace_file', type=str, help='trace log from a pulp core')
parser.add_argument('elf_file', type=str,
                    help='elf file that was ran, producing the trace log')
# NOTE: the short flag used to be registered as '-o,' (stray comma), which
# only worked through argparse's prefix matching and garbled the help text.
parser.add_argument('-o', '--output', type=str, help="""write to file instead of
stdout""")
parser.add_argument('-t', '--truncate', action='store_true',
                    help='truncate overlong text')
parser.add_argument('-n', '--numeric', action='store_true',
                    help='show numeric register names')
parser.add_argument('--no-aliases', action='store_true',
                    help='do not use aliases for instruction names')
parser.add_argument('--cycles', action='store_true',
                    help='show cycle count extracted from log')
parser.add_argument('--time', action='store_true',
                    help='show passed time extracted from log')


args = parser.parse_args()

# input files used by the rest of the script
trace_filename = args.trace_file
elf_filename = args.elf_file


# ABI aliases of the integer registers, indexed by register number
# (entry i is the alias of register "x<i>").
abi_names = ("zero ra sp gp tp t0 t1 t2 s0 s1 "
             "a0 a1 a2 a3 a4 a5 a6 a7 "
             "s2 s3 s4 s5 s6 s7 s8 s9 s10 s11 "
             "t3 t4 t5 t6").split()

# Each (numeric name, alias) pair is expanded into four search/replace
# patterns that carry a neighbouring character (leading space, leading
# parenthesis, trailing colon, trailing equals sign). Anchoring on that
# context prevents false positives; ugly, but works.
regs_map = [[template % ('x%d' % number), template % alias]
            for number, alias in enumerate(abi_names)
            for template in (' %s', '(%s', '%s:', '%s=')]

# The replacements are applied in dictionary order, so the higher numbered
# registers have to come first: otherwise the " x3" pattern would hit the
# start of " x31" and turn it into " gp1".
regs_alias = OrderedDict(reversed(regs_map))
# Map of instruction address (int) -> (symbolic location, instruction text),
# filled from the disassembly of the ELF file.
objdump_insns = dict()

# Prefer the toolchain installed under $RISCV; otherwise rely on PATH.
riscv_root = os.getenv('RISCV')
if riscv_root:
    objdump_bin = os.path.join(riscv_root, 'bin',
                               'riscv32-unknown-elf-objdump')
else:
    objdump_bin = 'riscv32-unknown-elf-objdump'

# One disassembly line as printed with --prefix-addresses:
#   group(1) = instruction address (hex)
#   group(2) = symbolic address, e.g. <main+0x4>
#   group(3) = instruction text
# The symbolic part accepts anything up to the closing '>' so that labels
# containing '.', '-' or '$' (e.g. <foo.constprop.0+0x8>, <.L3>) match too.
objdump_line_re = re.compile(r'^\s*([0-9a-f]+)\s+(<[^>]*>)\s+(.*)')

with subprocess.Popen([objdump_bin, "--prefix-addresses"]
                      + (['-Mnumeric'] if args.numeric else [])
                      + (['-Mno-aliases'] if args.no_aliases else [])
                      + ["-d", elf_filename],
                      stdout=subprocess.PIPE) as proc:
    for line in proc.stdout:
        # tolerate non-ASCII bytes instead of aborting the whole run
        line = line.decode("utf-8", errors="replace")
        match = objdump_line_re.match(line)
        if match:
            objdump_insns[int(match.group(1), 16)] = (
                match.group(2), match.group(3).replace("\t", " "))


def truncate_string(string, length):
    """Shorten string to at most length characters.

    A string that already fits is returned unchanged. A longer string is cut
    and its tail replaced by '..' so that the result is exactly length
    characters wide. For length <= 2 only (a prefix of) the '..' marker is
    returned, since no text fits next to it.
    """
    if len(string) <= length:
        return string
    if length <= 2:
        return '..'[:max(length, 0)]
    return string[:length - 2] + '..'


# redirect stdout to file if desired
sys.stdout = open(args.output, "w") if args.output else sys.stdout

# One trace line up to and including the instruction text:
#   <time><unit> <cycles> <pc> <instruction word> <instruction text...>
trace_insn_re = re.compile(
    r'^\s*[0-9a-f]+[nmu]s\s+[0-9]+\s+[0-9a-f]+\s+[0-9a-f]+\s+(.*)')


def _alias_registers(text):
    """Return text with the patterns of regs_alias replaced by ABI names."""
    # TODO: this might not be reliable if we have values like x10 in
    # the registers
    return reduce(lambda a, kv: a.replace(*kv), regs_alias.items(), text)


try:
    with open(trace_filename, "r") as f:
        # skip trace file "header"
        f.readline()
        # parse instructions
        for line in f:
            fields = line.split()
            if len(fields) < 3:
                # blank or cut-off line: no time/cycles/address to decode
                continue
            timestamp, cycles, addr = fields[:3]

            # this is a dirty heuristic which figures out if we have register
            # values in the trace file TODO: improve
            bound = 80
            if len(line) > bound:
                reg_vals = line[bound:].strip()
                insn_only = line[:bound-1].strip()
            else:
                # short line: no register values, the whole line is the
                # instruction part (it used to be left empty, which crashed
                # the fallback path below)
                reg_vals = ""
                insn_only = line.strip()

            # every 'x' in the address field is mapped to '0' before the hex
            # conversion (covers a "0x" prefix; presumably also unknown 'x'
            # digits written by the simulator - TODO confirm)
            insn_addr = int(addr.replace("x", "0"), 16)

            if not args.numeric:
                reg_vals = _alias_registers(reg_vals)

            if args.time:
                print('%-12s ' % timestamp, end='')

            if args.cycles:
                print('%-12d ' % (int(cycles)), end='')

            if insn_addr in objdump_insns:
                source_location, objdump_insn_str = objdump_insns[insn_addr]
                if args.truncate:
                    source_location = truncate_string(source_location, 40)
                    objdump_insn_str = truncate_string(objdump_insn_str, 40)
                print("%08x: %-40s %-40s %-20s" % (insn_addr,
                                                   source_location,
                                                   objdump_insn_str,
                                                   reg_vals))
            else:
                # address unknown to objdump: fall back to the instruction
                # text recorded in the trace itself
                match = trace_insn_re.match(insn_only)
                if match:
                    insn_str = match.group(1)
                else:
                    # unexpected layout: take everything after the four
                    # leading columns instead of failing on a missing match
                    insn_str = ' '.join(insn_only.split()[4:])

                if not args.numeric:
                    insn_str = _alias_registers(insn_str)

                print("%08x: %-40s %-40s %-20s" % (insn_addr,
                                                   "",  # no objdump info
                                                   insn_str,  # insn name
                                                   reg_vals))
finally:
    # close the output file even when parsing fails, and hand the real
    # stdout back so later writes do not hit a closed file
    if args.output:
        sys.stdout.close()
        sys.stdout = sys.__stdout__
