#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function

import argparse
import os
import sys
import yaml
import json
import h5py
import hls4ml

from shutil import copyfile

# Name of the YAML configuration file that the `build` command looks for
# inside the project directory.
config_filename = 'hls4ml_config.yml'

# Banner used as the description of the top-level argument parser. It is
# rendered verbatim because `main` uses RawDescriptionHelpFormatter.
hls4ml_description = """

        ╔╧╧╧╗────o
    hls ║ 4 ║ ml   - Machine learning inference in FPGAs
   o────╚╤╤╤╝
"""

def main():
    """Parse the command line and dispatch to the selected sub-command.

    Four sub-commands are registered: ``config``, ``convert``, ``build`` and
    ``report``. Each sub-parser stores its handler function in ``args.func``;
    when no sub-command is given, the usage string is printed instead.
    """
    parser = argparse.ArgumentParser(description=hls4ml_description, formatter_class=argparse.RawDescriptionHelpFormatter)
    subparsers = parser.add_subparsers()

    config_parser = subparsers.add_parser('config', help='Create a conversion configuration file')
    convert_parser = subparsers.add_parser('convert', help='Convert Keras or ONNX model to HLS')
    build_parser = subparsers.add_parser('build', help='Build generated HLS project')
    report_parser = subparsers.add_parser('report', help='Show synthesis report of an HLS project')

    # hls4ml config
    config_parser.add_argument('-m', '--model', help='Model file to convert (Keras .h5 or .json file, ONNX .onnx file, or TensorFlow .pb file)', default=None)
    config_parser.add_argument('-w', '--weights', help='Optional weights file (if Keras .json file is provided)', default=None)
    config_parser.add_argument('-p', '--project', help='Project name', default='myproject')
    config_parser.add_argument('-d', '--dir', help='Project output directory', default='my-hls-test')
    config_parser.add_argument('-f', '--fpga', help='FPGA part', default='xcku115-flvb2104-2-i')
    config_parser.add_argument('-bo', '--board', help='Board used.', default='pynq-z2')
    config_parser.add_argument('-ba', '--backend', help='Backend to use (Vivado, VivadoAccelerator)', default='Vivado')
    config_parser.add_argument('-c', '--clock', help='Clock frequency (ns)', type=int, default=5)
    config_parser.add_argument('-g', '--granularity', help='Granularity of configuration. One of "model", "type" or "name"', default='model')
    config_parser.add_argument('-x', '--precision', help='Default precision', default='ap_fixed<16,6>')
    config_parser.add_argument('-r', '--reuse-factor', help='Default reuse factor', type=int, default=1)
    config_parser.add_argument('-o', '--output', help='Output file name', default=None)
    config_parser.set_defaults(func=config)

    # hls4ml convert
    convert_parser.add_argument('-c', '--config', help='Configuration file', default=None)
    convert_parser.set_defaults(func=convert)

    # hls4ml build
    build_parser.add_argument('-p', '--project', help='Project directory', default=None)
    build_parser.add_argument('-c', '--simulation', help='Run C simulation', action='store_true', default=False)
    build_parser.add_argument('-s', '--synthesis', help='Run C/RTL synthesis', action='store_true', default=False)
    build_parser.add_argument('-r', '--co-simulation', help='Run C/RTL co-simulation', action='store_true', default=False)
    build_parser.add_argument('-v', '--validation', help='Run C/RTL validation', action='store_true', default=False)
    build_parser.add_argument('-e', '--export', help='Export IP (implies -s)', action='store_true', default=False)
    build_parser.add_argument('-l', '--vivado-synthesis', help='Run Vivado synthesis (implies -s)', action='store_true', default=False)
    # The help text lists every step that `build` switches on for --all.
    build_parser.add_argument('-a', '--all', help='Run C simulation, C/RTL synthesis, C/RTL co-simulation, C/RTL validation, IP export and Vivado synthesis', action='store_true')
    build_parser.add_argument('--reset', help='Remove any previous builds', action='store_true', default=False)
    build_parser.set_defaults(func=build)

    # hls4ml report
    report_parser.add_argument('-p', '--project', help='Project directory', default=None)
    report_parser.add_argument('-f', '--full', help='Show full report', action='store_true', default=False)
    report_parser.set_defaults(func=report)

    parser.add_argument('--version', action='version', version='%(prog)s {version}'.format(version=hls4ml.__version__))

    args = parser.parse_args()
    # `func` is only set by the sub-parsers, so it is absent when the user
    # did not pick a sub-command.
    if hasattr(args, 'func'):
        args.func(args)
    else:
        parser.print_usage()

def _keras_hls_config(model_arch, args):
    """Build the 'HLSConfig' section for a Keras model architecture dict."""
    return hls4ml.utils.config_from_keras_model(
        model_arch,
        granularity=args.granularity,
        default_precision=args.precision,
        default_reuse_factor=args.reuse_factor
    )


def config(args):
    """Create a conversion configuration and write it as YAML.

    The configuration is written to ``args.output`` (``.yml`` is appended
    when missing) or, when no output file is given, dumped to stdout.

    Args:
        args: Parsed arguments of the ``config`` sub-command (model, weights,
            project, dir, fpga, board, backend, clock, granularity,
            precision, reuse_factor, output).

    Exits with status 1 when the model file is missing, when a Keras JSON
    model is given without a weights file, when the h5 file carries no
    'model_config' attribute, or when the model is an ONNX/TensorFlow file.
    """
    model_path = args.model
    if model_path is None:
        print('Model file (-m or --model) must be provided.')
        sys.exit(1)

    # Validate the request up front so nothing is built for inputs that are
    # going to be rejected anyway.
    if model_path.endswith('.onnx'):
        print('Creating configuration for ONNX models is not supported yet.')
        sys.exit(1)
    if model_path.endswith('.pb'):
        print('Creating configuration for Tensorflow models is not supported yet.')
        sys.exit(1)
    if model_path.endswith('.json') and args.weights is None:
        print('Weights file (-w or --weights) must be provided when parsing from JSON file.')
        sys.exit(1)

    # Named `conf` so the local does not shadow this function.
    conf = hls4ml.utils.config.create_config(
        backend=args.backend,
        output_dir=args.dir,
        project_name=args.project,
        part=args.fpga,
        board=args.board,
        clock_period=args.clock
    )

    if model_path.endswith('.h5'):
        conf['KerasH5'] = model_path

        with h5py.File(model_path, mode='r') as h5file:
            model_arch = h5file.attrs.get('model_config')

        if model_arch is None:
            print('No model found in the provided h5 file.')
            sys.exit(1)
        # Depending on the h5py version the attribute is returned either as
        # bytes or as str; only bytes need decoding.
        if isinstance(model_arch, bytes):
            model_arch = model_arch.decode('utf-8')

        conf['HLSConfig'] = _keras_hls_config(json.loads(model_arch), args)
    elif model_path.endswith('.json'):
        conf['KerasJson'] = model_path
        conf['KerasH5'] = args.weights

        with open(model_path) as json_file:
            model_arch = json.load(json_file)
        conf['HLSConfig'] = _keras_hls_config(model_arch, args)
    # NOTE(review): any other file extension falls through and produces a
    # configuration without model entries, as before.

    if args.output is not None:
        outname = args.output
        if not outname.endswith('.yml'):
            outname += '.yml'
        print('Writing config to {}'.format(outname))
        with open(outname, 'w') as outfile:
            yaml.dump(conf, outfile, default_flow_style=False, sort_keys=False)
    else:
        print('Config output:')
        yaml.dump(conf, sys.stdout, default_flow_style=False, sort_keys=False)

def convert(args):
    """Convert the model described by a configuration file and write the project.

    Args:
        args: Parsed arguments of the ``convert`` sub-command; ``args.config``
            is the path of the configuration file.

    Exits with status 1 when no configuration file is given, in line with
    how the other sub-commands treat a missing required option.
    """
    if args.config is None:
        print('Configuration file (-c or --config) must be provided.')
        sys.exit(1)

    model = hls4ml.converters.convert_from_config(args.config)

    if model is not None:
        model.write()

def build(args):
    """Run the Vivado HLS build script of a generated project.

    Args:
        args: Parsed arguments of the ``build`` sub-command. ``args.project``
            is the project directory; the boolean flags select the build
            steps passed on to ``build_prj.tcl``.

    Exits with status 1 when no project directory is given, when
    ``vivado_hls`` is required but not found on PATH, or when the tool
    cannot be launched.

    Raises:
        NotImplementedError: For the 'Intel' and 'Mentor' backends.
        Exception: For any other unknown backend value.
    """
    # Imported here because the module only imports `copyfile` from shutil
    # and does not import subprocess at all.
    import shutil
    import subprocess

    if args.project is None:
        print('Project directory (-p or --project) must be provided.')
        sys.exit(1)

    reset = int(args.reset)
    csim = int(args.simulation)
    synth = int(args.synthesis)
    cosim = int(args.co_simulation)
    validation = int(args.validation)
    export = int(args.export)
    vsynth = int(args.vivado_synthesis)
    if args.all:
        # --all enables every build step; `reset` stays as requested.
        csim = synth = cosim = validation = export = vsynth = 1

    yaml_config = hls4ml.converters.parse_yaml_config(os.path.join(args.project, config_filename))

    vivado_hls = shutil.which('vivado_hls')

    # Check if vivado_hls is available
    if 'linux' in sys.platform or 'darwin' in sys.platform:
        backend = yaml_config.get('Backend', 'Vivado')
        if backend in ['Vivado', 'VivadoAccelerator']:
            if vivado_hls is None:
                print('Vivado HLS installation not found. Make sure "vivado_hls" is on PATH.')
                sys.exit(1)
        elif backend == 'Intel':
            raise NotImplementedError
        elif backend == 'Mentor':
            raise NotImplementedError
        else:
            raise Exception('Backend values can be [Vivado, Intel, Mentor, VivadoAccelerator]')

    # The build options are handed to the Tcl script as one single argument.
    build_opts = 'reset={reset} csim={csim} synth={synth} cosim={cosim} validation={validation} export={export} vsynth={vsynth}'.format(
        reset=reset, csim=csim, synth=synth, cosim=cosim, validation=validation, export=export, vsynth=vsynth)

    # An argument list plus `cwd` avoids building a shell command line, so
    # project paths containing spaces or shell metacharacters work.
    command = [vivado_hls or 'vivado_hls', '-f', 'build_prj.tcl', build_opts]
    try:
        subprocess.call(command, cwd=args.project)
    except OSError as err:
        print('Failed to run "vivado_hls": {}'.format(err))
        sys.exit(1)

def report(args):
    """Show the Vivado report of an HLS project.

    Args:
        args: Parsed arguments of the ``report`` sub-command; ``args.project``
            is the project directory and ``args.full`` requests the full
            report.

    Exits with status 1 when no project directory is given.
    """
    project_dir = args.project
    if project_dir is not None:
        hls4ml.report.read_vivado_report(project_dir, args.full)
        return

    print('Project directory (-p or --project) must be provided.')
    sys.exit(1)

# Run the command-line interface only when this file is executed as a script.
if __name__== "__main__":
    main()
