#!/usr/bin/env python3

# This file is part of tf-plan.

# tf-plan is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# tf-plan is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with tf-plan. If not, see <http://www.gnu.org/licenses/>.


import argparse


def parse_args(argv=None):
    """Parse the command-line arguments of the planner script.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads the arguments from sys.argv[1:].

    Returns:
        argparse.Namespace with the attributes rddl, batch_size, horizon,
        epochs, learning_rate, viz and verbose.
    """
    description = 'Planning via gradient-based optimization in TensorFlow.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('rddl', type=str, help='RDDL filepath')
    # Help texts interpolate %(default)s so the documented default can
    # never drift away from the value actually configured.
    parser.add_argument(
        '-b', '--batch-size',
        type=int, default=100,
        help='number of trajectories in a batch (default=%(default)s)'
    )
    parser.add_argument(
        '-hr', '--horizon',
        type=int, default=40,
        help='number of timesteps (default=%(default)s)'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int, default=1000,
        help='number of training epochs (default=%(default)s)'
    )
    parser.add_argument(
        '-lr', '--learning-rate',
        type=float, default=0.01,
        help='optimizer learning rate (default=%(default)s)'
    )
    parser.add_argument(
        '--viz',
        default='generic',
        choices=('generic', 'navigation'),
        help='type of visualizer (default=%(default)s)'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='verbosity mode'
    )
    return parser.parse_args(argv)


def read_file(path):
    """Return the entire content of the text file located at `path`."""
    with open(path, mode='r') as stream:
        contents = stream.read()
    return contents


def parse_rddl(path):
    """Parse the RDDL file at `path` and return the parser's result.

    A pyrddl RDDLParser is created and built first; the file is then read
    with read_file() and its text handed to the parser.
    """
    # Imported lazily so the dependency is only needed when parsing.
    from pyrddl.parser import RDDLParser
    rddl_parser = RDDLParser()
    rddl_parser.build()
    source_text = read_file(path)
    return rddl_parser.parse(source_text)


def compile(rddl):
    """Wrap the parsed RDDL model in a tfrddlsim Compiler.

    The compiler is constructed with batch_mode=True and returned as is.
    NOTE(review): this name shadows the built-in compile() inside this
    module; it is kept because callers rely on it.
    """
    # Imported lazily so the dependency is only needed when compiling.
    from tfrddlsim.compiler import Compiler
    return Compiler(rddl, batch_mode=True)


def optimize(compiler, batch_size, horizon, epochs, learning_rate=0.001):
    """Run the action optimizer and return whatever it produces.

    An OpenLoopPolicy is created from (compiler, batch_size, horizon) and
    built in 'training' mode; an ActionOptimizer is then built with
    `learning_rate` and run for `epochs`.

    NOTE(review): the default learning rate here (0.001) differs from the
    command-line default (0.01); the script always passes the CLI value.
    """
    from tfplan.train.optimizer import ActionOptimizer
    from tfplan.train.policy import OpenLoopPolicy

    training_policy = OpenLoopPolicy(compiler, batch_size, horizon)
    training_policy.build('training')

    action_optimizer = ActionOptimizer(compiler, training_policy)
    action_optimizer.build(learning_rate)
    return action_optimizer.run(epochs)


def evaluate(compiler, horizon, policy_variables):
    """Evaluate the given policy variables and return the evaluator output.

    An OpenLoopPolicy with a batch size of 1 is built in 'test' mode from
    `policy_variables`, and an ActionEvaluator is run on it.
    """
    from tfplan.test.evaluator import ActionEvaluator
    from tfplan.train.policy import OpenLoopPolicy

    single_batch = 1
    test_policy = OpenLoopPolicy(compiler, single_batch, horizon)
    test_policy.build('test', policy_variables)

    return ActionEvaluator(compiler, test_policy).run()


def display(compiler, trajectories, visualizer, verbose=True):
    """Render trajectories with the visualizer selected by name.

    Args:
        compiler: object passed as first argument to the visualizer.
        trajectories: data handed to the visualizer's render() method.
        visualizer: either 'generic' (BasicVisualizer) or 'navigation'
            (NavigationVisualizer).
        verbose: flag passed as second argument to the visualizer.

    Raises:
        ValueError: if `visualizer` is not one of the supported names.
    """
    # Validate first: an unknown name used to fall through both branches
    # and fail later with an opaque UnboundLocalError on `viz`.
    supported = ('generic', 'navigation')
    if visualizer not in supported:
        raise ValueError(
            'unknown visualizer {!r}; expected one of {}'.format(
                visualizer, ', '.join(supported)))

    from tfrddlsim.viz.visualizer import BasicVisualizer
    from tfrddlsim.viz.navigation_visualizer import NavigationVisualizer
    if visualizer == 'generic':
        viz = BasicVisualizer(compiler, verbose)
    else:
        viz = NavigationVisualizer(compiler, verbose)
    viz.render(trajectories)


if __name__ == '__main__':

    # parse CLI arguments
    args = parse_args()

    # read RDDL file
    rddl = parse_rddl(args.rddl)

    # compile RDDL to TensorFlow
    rddl2tf = compile(rddl)

    # optimize actions
    solution = optimize(
        rddl2tf,
        args.batch_size,
        args.horizon,
        args.epochs,
        args.learning_rate)

    # evaluate solution
    trajectories = evaluate(rddl2tf, args.horizon, solution)

    # render visualization; forward the -v/--verbose flag, which was
    # previously parsed but never used (display always ran verbose)
    display(rddl2tf, trajectories, args.viz, args.verbose)
