#!/usr/bin/env python3

# Copyright (c) 2021, Ericson Joseph
# 
# All rights reserved.
# 
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
# 
#     * Redistributions of source code must retain the above copyright notice,
#       this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright notice,
#       this list of conditions and the following disclaimer in the documentation
#       and/or other materials provided with the distribution.
#     * Neither the name of pyMakeTool nor the names of its contributors
#       may be used to endorse or promote products derived from this software
#       without specific prior written permission.
# 
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import re
import os
import subprocess
import argparse
from pathlib import Path
import json
from typing import Text
import gi

gi.require_version("Gtk", "3.0")
from gi.repository import Gtk

# ---------------------

def toKB(num):
    """Format a byte count for display.

    Values below 1024 are rendered as ``'<num> B'``; anything else is divided
    by 1024 and rendered with two decimals as ``'<value> KB'``.
    """
    in_bytes = num < 1024
    return '{} B'.format(num) if in_bytes else '{:.2f} KB'.format(num/1024)

class SectionHeader:
    """One section-header entry built from a dict of string fields.

    ``args`` must provide the keys ``nr``, ``name``, ``type``, ``addr``,
    ``off``, ``size``, ``es``, ``flg``, ``lk``, ``inf`` and ``al``.
    ``addr``, ``off`` and ``size`` are hexadecimal digits without a ``0x``
    prefix; ``lk``, ``inf`` and ``al`` are decimal. The remaining fields are
    stored as the strings they arrive as.
    """

    def __init__(self, args):
        self.nr   = args['nr']
        self.name = args['name']
        self.type = args['type']
        self.addr = int("0x"+args['addr'], 16)
        self.off  = int("0x"+args['off'], 16)
        self.size = int("0x"+args['size'], 16)
        self.es   = args['es']
        self.flg  = args['flg']
        self.lk   = int(args['lk'])
        self.inf  = int(args['inf'])
        self.al   = int(args['al'])
        # Load address stays 0 until setLoadAddr() is called.
        self.load_addr = 0
        # Objects attached to this section by the caller.
        self.symbols = []

    def setLoadAddr(self, load_addr):
        """Record the load address of this section."""
        self.load_addr = load_addr

    def toJSON(self):
        """Return an indented JSON dump of this object and any nested objects."""
        return json.dumps(self, default=lambda x: x.__dict__, indent=4)

    def __repr__(self) -> str:
        # ``symbols`` may hold arbitrary objects; without ``default`` json.dumps
        # raises TypeError as soon as the list is non-empty.
        return json.dumps(self.__dict__, default=lambda x: x.__dict__)


class SectionMap:
    """Plain record holding a section's name, address, length and load address."""

    def __init__(self, name, addr, length, loadAddr=0):
        # Assignment order matters: __repr__ serialises __dict__ in insertion order.
        self.name, self.addr, self.loadAddr, self.length = name, addr, loadAddr, length

    def __repr__(self) -> str:
        """Return the instance attributes as a compact JSON object."""
        return json.dumps(vars(self))

class Symbol:
    """One symbol-table entry built from a dict of string fields.

    ``num`` and ``size`` are parsed as decimal integers and ``addr`` as
    hexadecimal digits without prefix; ``type``, ``bind``, ``vis``, ``ndx``
    and ``name`` are kept as given.
    """

    # Fields copied verbatim from the input dict, in attribute order.
    _TEXT_FIELDS = ('type', 'bind', 'vis', 'ndx', 'name')

    def __init__(self, args):
        self.num = int(args['num'])
        # ``value`` and ``addr`` are two names for the same parsed address.
        self.value = self.addr = int("0x" + args['addr'], 16)
        self.size = int(args['size'])
        for key in self._TEXT_FIELDS:
            setattr(self, key, args[key])

    def __repr__(self) -> str:
        """Return the instance attributes as a compact JSON object."""
        return json.dumps(vars(self))

class MemRegion:
    """A memory region described by name, attributes, origin and length.

    ``end`` is derived as ``origin + length``. ``using`` starts at 0 and
    ``sections`` starts empty; both are filled in by the caller.
    """

    def __init__(self, name, attr, origin, length):
        # Assignment order matters: __repr__ serialises __dict__ in insertion order.
        self.name, self.attr = name, attr
        self.origin, self.length = origin, length
        self.end = origin + length
        self.using = 0
        self.sections = []

    @staticmethod
    def _fields_of(obj):
        """JSON fallback: serialise any nested object through its attribute dict."""
        return obj.__dict__

    def __repr__(self) -> str:
        """Return the attributes (including nested objects) as compact JSON."""
        return json.dumps(vars(self), default=MemRegion._fields_of)

    def toJSON(self):
        """Return the same data as ``__repr__`` but indented by four spaces."""
        return json.dumps(self, default=MemRegion._fields_of, indent=4)

class GtkRegionsWin:
    """GTK 3 window listing memory regions and their sections and symbols.

    Every displayed field occupies two adjacent model columns: the raw value
    (used as the sort column) followed by its display text. The row builders
    below therefore return flat lists of ``value, text, value, text, ...``.
    """

    def __init__(self, regions:list, gladefile):
        """Store the data to display.

        regions:   objects exposing ``name``, ``origin``, ``end``, ``length``,
                   ``using`` and ``sections``.
        gladefile: path of the Glade UI definition loaded by ``show()``.
        """
        self.regions = regions
        self.gladefile = gladefile
        self.current_filter = None
        # Created by show() / getGtkTreeStore().
        self.store = None
        self.treeView = None
        self.tvregios = None

    @staticmethod
    def _flatten(pairs):
        """Flatten ``[value, text]`` pairs into a single model row."""
        return [item for pair in pairs for item in pair]

    @staticmethod
    def _hex(value):
        """Format an address as zero-padded hexadecimal, 10 characters wide."""
        return "{0:#0{1}x}".format(value, 10)

    def _appendTextColumns(self, view, headers):
        """Append one sortable text column per header to ``view``.

        Returns the model index that follows the last value/text pair used.
        """
        idx = 0
        for h in headers:
            renderer_editabletext = Gtk.CellRendererText()
            renderer_editabletext.set_property("editable", True)
            column_editabletext = Gtk.TreeViewColumn(
                h, renderer_editabletext, text=idx+1
            )
            column_editabletext.set_resizable(True)
            column_editabletext.set_sort_column_id(idx)
            view.append_column(column_editabletext)
            idx += 2
        return idx

    def getSectionRow(self, sec):
        """Row for a section: name, run address, load address (if any), size."""
        if sec.load_addr > 0:
            load = [sec.load_addr, self._hex(sec.load_addr)]
        else:
            load = [0, ""]
        return self._flatten([
            [sec.name, sec.name],
            [sec.addr, self._hex(sec.addr)],
            load,
            [sec.size, toKB(sec.size)],
        ])

    def getMemRegionRow(self, r):
        """Row for the regions list: bounds, size, free, used and usage percent."""
        # A zero-length region would otherwise raise ZeroDivisionError.
        percent = (r.using/r.length)*100 if r.length else 0.0
        return self._flatten([
            [r.name, r.name],
            [r.origin, self._hex(r.origin)],
            [r.end, self._hex(r.end)],
            [r.length, toKB(r.length)],
            [r.length-r.using, toKB(r.length-r.using)],
            [r.using, toKB(r.using)],
            [percent, '{:.2f}%'.format(percent)],
            ['', ''],
        ])

    def getRegionRow(self, r):
        """Top-level row of the details tree: name, origin, empty load address, size."""
        return self._flatten([
            [r.name, r.name],
            [r.origin, self._hex(r.origin)],
            [0, ""],
            [r.length, toKB(r.length)],
        ])

    def getSymbolRow(self, sec, sym):
        """Row for a symbol inside ``sec``.

        When the section has a load address, the symbol's load address is the
        section load address plus the symbol's offset within the section.
        """
        if sec.load_addr > 0:
            loadAddr = sec.load_addr + (sym.addr - sec.addr)
            load = [loadAddr, self._hex(loadAddr)]
        else:
            load = [0, ""]
        return self._flatten([
            [sym.name, sym.name],
            [sym.addr, self._hex(sym.addr)],
            load,
            [sym.size, toKB(sym.size)],
        ])

    def getGtkTreeStore(self, search=None):
        """Build the region > section > symbol tree model.

        Without ``search`` everything is included. With ``search``, a section
        is kept when its name starts with the text (all of its symbols are then
        shown) or when at least one of its symbols starts with it (only those
        symbols are shown); regions left without sections are dropped.
        """
        self.store = Gtk.TreeStore(
            str, str,
            int, str,
            int, str,
            int, str)
        for r in self.regions:
            region_node = self.store.append(None, self.getRegionRow(r))
            shown_sections = 0
            for sec in r.sections:
                if not search or sec.name.startswith(search):
                    visible = sec.symbols
                else:
                    visible = [sym for sym in sec.symbols if sym.name.startswith(search)]
                    if not visible:
                        continue
                shown_sections += 1
                section_node = self.store.append(region_node, self.getSectionRow(sec))
                for sym in visible:
                    self.store.append(section_node, self.getSymbolRow(sec, sym))
            # Only a filtered view hides regions that have nothing to show.
            if search and not shown_sections:
                self.store.remove(region_node)
        return self.store

    def getGtkRegionsModel(self):
        """Build the flat list model backing the regions summary view."""
        store = Gtk.ListStore(
            str, str,
            int, str,
            int, str,
            int, str,
            int, str,
            int, str,
            float, str,
            str, str)
        for r in self.regions:
            store.append(self.getMemRegionRow(r))
        return store

    def show(self):
        """Load the Glade file, populate both views and run the GTK main loop."""
        builder = Gtk.Builder()
        builder.add_from_file(self.gladefile)

        # Memory regions
        self.tvregios = builder.get_object("tvRegions")
        idx = self._appendTextColumns(self.tvregios, [
            'Region',
            'Start Address',
            'End Address',
            'Size',
            'Free',
            'Used'
        ])

        renderer_progress = Gtk.CellRendererProgress()
        renderer_progress.set_fixed_size(200,20)
        # NOTE(review): as with ``text=`` above, keyword arguments here appear
        # to bind renderer properties to model column indexes, so ``inverted=1``
        # would tie "inverted" to column 1 (a str column). Presumably a constant
        # was intended -- confirm before changing.
        column_progress = Gtk.TreeViewColumn(
            "Using", renderer_progress, value=idx, inverted=1, text=(idx+1)
        )
        self.tvregios.append_column(column_progress)
        idx += 2

        # Trailing empty column.
        renderer_text = Gtk.CellRendererText()
        column_text = Gtk.TreeViewColumn("", renderer_text, text=idx+1)
        self.tvregios.append_column(column_text)

        self.tvregios.set_model(self.getGtkRegionsModel())

        # Memory details
        self.treeView = builder.get_object("treeView")
        self.treeView.set_reorderable(False)
        self.treeView.set_search_column(0)
        self._appendTextColumns(self.treeView, [
            'Name',
            'Run address',
            'Load address',
            'Size'
        ])

        entrySearch = builder.get_object("entrySearch")
        entrySearch.connect("changed", self.searchChanged)
        self.treeView.set_search_entry(entrySearch)

        self.treeView.set_model(self.getGtkTreeStore())
        self.treeView.expand_all()

        window = builder.get_object("window1")
        window.set_size_request(800, 600)
        window.connect("destroy", Gtk.main_quit)
        window.show_all()
        Gtk.main()

    def searchChanged(self, nip):
        """Signal handler: rebuild the details tree using the entry text as filter."""
        # An empty entry means "no filter".
        self.current_filter = nip.get_text() or None
        self.treeView.set_model(self.getGtkTreeStore(search=self.current_filter))
        self.treeView.expand_all()


class MemRegionView:
    """Console table of memory-region usage with a Unicode bar per region."""

    PRINT_FORMAT = '| {0:<12}| {1:<15}| {2:<15}| {3:>12}| {4:>12}| {5:>12} {6:<11} {7:>7} |'

    def __init__(self, regions):
        """regions: objects exposing name, origin, end, length and using."""
        self.regions = regions

    def __printBar(self, total, using, length=10, usingColor=False):
        """Return a ``|....|`` usage bar ``length`` cells wide.

        The bar has a resolution of one eighth of a cell: whole cells use the
        full block U+2588 and the remainder uses the matching left-partial
        block (U+2589 = 7/8 ... U+258F = 1/8). With ``usingColor`` the bar is
        wrapped in ANSI colour codes chosen from the usage ratio (below 60 %,
        below 90 %, otherwise).
        """
        if total == 0:
            # Avoid dividing by zero: an empty region renders as an empty bar.
            total = 1
            using = 0

        rate = using/total
        if rate < 0.60:
            color = '\033[92m'
        elif rate < 0.90:
            color = '\033[93m'
        else:
            color = '\033[91m'

        # Integer arithmetic avoids float rounding at exact cell boundaries;
        # the clamp keeps the bar inside its cells when using > total.
        eighths = int(using * length * 8 // total)
        eighths = max(0, min(eighths, length * 8))
        full, partial = divmod(eighths, 8)

        parts = ['|']
        if usingColor:
            parts.append(color)
        parts.append(chr(0x2588) * full)
        if partial:
            # 0x2590 - n is the "left n eighths" block for n in 1..7.
            parts.append(chr(0x2590 - partial))
        parts.append(' ' * (length - full - (1 if partial else 0)))
        if usingColor:
            parts.append('\033[0m')
        parts.append('|')
        return ''.join(parts)

    def __printRegion(self, region):
        """Print one table row for ``region``."""
        if region.length == 0:
            # Zero-length regions get a fixed row; this row must be printed
            # like every other one rather than returned to the caller.
            print(self.PRINT_FORMAT.format(
                region.name,
                hex(region.origin),
                hex(region.end),
                '0.0K',
                '0.0K',
                '0.0K',
                self.__printBar(10, 0, usingColor=False),
                '{:.2f}%'.format(0.0)
            ))
            return
        name = region.name
        origin = hex(region.origin)
        end = hex(region.end)
        length = toKB(region.length)
        free = toKB(region.length - region.using)
        using = toKB(region.using)
        bar = self.__printBar(region.length, region.using, usingColor=False)
        perc = '{:.2f}%'.format((region.using/region.length)*100)
        print(self.PRINT_FORMAT.format(name, origin, end, length, free, using, bar, perc))

    def __printHeader(self):
        """Print the column titles."""
        print(self.PRINT_FORMAT.format('Region', 'Start', 'End', 'Size', 'Free', 'Used', '', 'Usage(%)'))

    def printAll(self):
        """Print the header followed by one row per region."""
        self.__printHeader()
        for r in self.regions:
            self.__printRegion(r)

# ----------------------

parser = argparse.ArgumentParser(description='Builder Analyzer for ARM firmware')
parser.add_argument('elf', type=str, help='ELF file')
parser.add_argument('-g', '--gtk', help='Show in gtk window', action="store_true")
parser.add_argument('-v', '--version', action='version', version='%(prog)s 2.0.0')
args = parser.parse_args()

# The readelf binary name is prefixed with $CROSS_COMPILE when it is set.
cross_compile_prefix = os.environ.get('CROSS_COMPILE', '')
readelf = cross_compile_prefix + 'readelf'

elffile = args.elf
if not Path(elffile).is_file():
    print("File {0} not found.".format(elffile))
    raise SystemExit(-1)

# The map file sits next to the ELF with only the extension swapped. Using
# with_suffix() (instead of a global ".elf" text replace) leaves directory
# names alone and never resolves to the ELF itself.
mapfile = str(Path(elffile).with_suffix('.map'))
if not Path(mapfile).is_file():
    print("File '{0}' not found".format(mapfile))
    raise SystemExit(1)


def _readelf_lines(option):
    """Run ``readelf <option> --wide`` on the ELF and return stdout as text lines."""
    try:
        completed = subprocess.run([readelf, option, '--wide', elffile], stdout=subprocess.PIPE)
    except FileNotFoundError:
        print("Command '{0}' not found.".format(readelf))
        raise SystemExit(1)
    return [raw.decode("utf-8") for raw in completed.stdout.splitlines()]


# --- Section headers (readelf -S) ---
sections = []
rgx_section_h = re.compile(r"^\s*\[[ ]*(?P<nr>[0-9]+)\][ ]+(?P<name>[a-zA-Z\\.0-9_-]+)[ ]+(?P<type>[a-zA-Z\\.0-9_-]+)[ ]+(?P<addr>[a-f0-9]+)[ ]+(?P<off>[a-f0-9]+)[ ]+(?P<size>[a-f0-9]+)[ ]+(?P<es>[a-f0-9]{2})[ ]+(?P<flg>[a-zA-Z]*)[ ]+(?P<lk>[0-9]+)[ ]+(?P<inf>[0-9]+)[ ]+(?P<al>[0-9]+)")
for text in _readelf_lines('-S'):
    match = rgx_section_h.search(text)
    if match:
        sections.append(SectionHeader(match.groupdict()))

# --- Symbols (readelf -s) ---
symbols = []
# NOTE(review): readelf is believed to print large symbol sizes as 0x-prefixed
# hex instead of decimal -- confirm against the binutils version in use. Such
# sizes are accepted here and converted to the decimal text Symbol expects.
rgx_symbols = re.compile(r"^\s+(?P<num>[0-9]+):[ ]+(?P<addr>[a-f0-9]+)[ ]+(?P<size>0x[a-f0-9]+|[a-f0-9]+)[ ]+(?P<type>[a-zA-Z0-9]+)[ ]+(?P<bind>[a-zA-Z0-9]+)[ ]+(?P<vis>[a-zA-Z0-9]+)[ ]+(?P<ndx>[a-zA-Z0-9]+)[ ]?(?P<name>[a-zA-Z\\.\\$0-9_-]*)")
for text in _readelf_lines('-s'):
    match = rgx_symbols.search(text)
    if match:
        fields = match.groupdict()
        if fields['size'].startswith('0x'):
            fields['size'] = str(int(fields['size'], 16))
        symbols.append(Symbol(fields))

symbols.sort(key=lambda x: x.addr, reverse=False)

# --- Linker map file ---
regions = []
sectionsMap = []

content = []    # whitespace-normalised lines of the "Memory Configuration" table
upstream = []   # column-0 lines starting with '.', joined with a wrapped next line
memlines = False
memNewLines = 2         # the table ends at the second blank line after its title
readNextLine = False    # previous '.' line held only a name; values are on the next line
with open(mapfile, 'r') as mapstream:
    for line in mapstream:
        if line.startswith('Memory Configuration'):
            memlines = True

        if memlines:
            if line.strip() == '':
                memNewLines -= 1
            if memNewLines == 0:
                memlines = False
            content.append(re.sub(r'\s+', ' ', line.strip()))
            continue

        if line.startswith('.'):
            stripped = line.strip()
            # Reset the flag on every '.' line so a stale continuation request
            # cannot glue an unrelated line onto this entry.
            readNextLine = ' ' not in stripped
            upstream.append(stripped)
        elif readNextLine:
            upstream[-1] = upstream[-1] + ' ' + line.strip()
            readNextLine = False

# "<name> <addr> <length> [load address <addr>]". The load address is optional,
# so a section whose name merely contains "load" still parses (previously this
# dereferenced a failed match and crashed).
rgx_map_section = re.compile(
    r'^(?P<name>[.][a-zA-Z\\.0-9_-]+)\s+(?P<addr>0x[0-9a-fA-F]+)\s+(?P<length>0x[0-9a-fA-F]+)'
    r'(?:\s+load\s+address\s+(?P<load_addr>0x[0-9a-fA-F]+))?')
for entry in upstream:
    m = rgx_map_section.search(entry)
    if not m:
        continue
    load_text = m.group('load_addr')
    sectionsMap.append(SectionMap(
        m.group('name'),
        int(m.group('addr'), 16),
        int(m.group('length'), 16),
        int(load_text, 16) if load_text else 0))

# Title, column header and the catch-all *default* entry are not real regions.
rgx_not_region = re.compile(
    r'^(Memory Configuration|Name\s+Origin\s+Length\s+Attributes|[*]default[*])')
for regin in content:
    if not regin or rgx_not_region.match(regin):
        continue
    values = regin.split(' ')
    try:
        origin = int(values[1], 16)
        length = int(values[2], 16)
    except (IndexError, ValueError):
        # Not a "<name> <origin> <length> ..." line; ignore it.
        continue
    regions.append(MemRegion(
        name=values[0],
        # The attributes column may be missing; keep the region anyway.
        attr=values[3] if len(values) > 3 else '',
        origin=origin,
        length=length
    ))
regions.sort(key=lambda x: x.origin, reverse=False)

# -------------------------------

# Attach sized symbols to their section: a symbol's ndx equals the section nr.
symbols_by_ndx = {}
for sym in symbols:
    if sym.size > 0:
        symbols_by_ndx.setdefault(sym.ndx, []).append(sym)
for sec in sections:
    if sec.size > 0 and sec.addr > 0:
        sec.symbols.extend(symbols_by_ndx.get(sec.nr, []))

if regions and sectionsMap:
    # Only sections with a real load address matter here; .bss is excluded.
    sectionsMap = [s for s in sectionsMap
                   if s.addr > 0 and s.length > 0 and s.loadAddr > 0 and s.name != '.bss']
    load_addr_by_name = {s.name: s.loadAddr for s in sectionsMap}
    for sec in sections:
        if sec.name in load_addr_by_name:
            sec.setLoadAddr(load_addr_by_name[sec.name])

    # A section counts against a region when its run address or its load
    # address falls inside [origin, end).
    for r in regions:
        for sec in sections:
            if sec.size > 0 and sec.addr > 0:
                runs_here = r.origin <= sec.addr < r.end
                loads_here = sec.load_addr > 0 and r.origin <= sec.load_addr < r.end
                if runs_here or loads_here:
                    r.using += sec.size
                    r.sections.append(sec)

    if args.gtk:
        # NOTE(review): the Glade file is looked up relative to the current
        # working directory -- confirm that is intended.
        win = GtkRegionsWin(regions, "pybagui.glade")
        win.show()
    else:
        view = MemRegionView(regions)
        view.printAll()