#! /usr/bin/env python3

# This file is part of MDP-ProbLog.

# MDP-ProbLog is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# MDP-ProbLog is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with MDP-ProbLog.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import time

from mdpproblog.mdp import MDP
from mdpproblog.value_iteration import ValueIteration

def parse():
	"""Parse the command-line arguments of the script.

	Two positional arguments are required: the paths to the MDP domain
	file and to the MDP instance file. Two optional float arguments tune
	the solver: -g/--gamma (default 0.9) and -e/--epsilon (default 0.1).

	Returns the argparse namespace with attributes `domain`, `instance`,
	`gamma` and `epsilon`.
	"""
	cli = argparse.ArgumentParser()

	# Required input files.
	cli.add_argument("domain", help="path to MDP domain file")
	cli.add_argument("instance", help="path to MDP instance file")

	# Optional float-valued solver parameters: (short flag, long flag, default, help).
	float_options = (
		("-g", "--gamma", 0.9, "discount factor (default=0.9)"),
		("-e", "--epsilon", 0.1, "maximum error   (default=0.1)"),
	)
	for short_flag, long_flag, default, description in float_options:
		cli.add_argument(short_flag, long_flag, type=float, default=default, help=description)

	return cli.parse_args()

def load_model(domain, instance):
	"""Build an MDP from a domain file and an instance file.

	`domain` and `instance` are file paths. Their contents are read in
	that order, concatenated without any separator, and the resulting
	text is passed to the MDP constructor.

	Returns the constructed MDP object.
	"""
	pieces = []
	for path in (domain, instance):
		with open(path, 'r') as source:
			pieces.append(source.read())
	return MDP(''.join(pieces))

def solve_model(mdp, gamma, epsilon):
	"""Solve `mdp` with value iteration.

	A ValueIteration solver is created for `mdp` and run with `gamma` and
	`epsilon` (in that order). The result of run() is returned unchanged;
	the script's entry point unpacks it as (V, policy, iterations).
	"""
	solver = ValueIteration(mdp)
	outcome = solver.run(gamma, epsilon)
	return outcome

def print_solution(V, policy, iterations):
	"""Print the value function followed by the policy.

	`V` maps each state to a numeric value and `policy` maps each state
	to an action. A state is an iterable of (name, value) pairs and is
	rendered as "name=value, name=value, ...". Values are printed with
	three decimal places. Blank lines are printed before, between and
	after the two listings.

	`iterations` is accepted but not used by this function.
	"""
	def render(state):
		# Turn the (name, value) pairs of a state into a readable label.
		return ', '.join("{f}={v}".format(f=name, v=val) for name, val in state)

	print()
	for key, value in V.items():
		print("Value({state}) = {value:.3f}".format(state=render(key), value=value))

	print()
	for key, action in policy.items():
		print("Policy({state}) = {action}".format(state=render(key), action=action))

	print()

if __name__ == '__main__':
	# Script entry point: parse the command line, load the model, solve it
	# with value iteration, then report the solution and timing statistics.
	args = parse()
	mdp = load_model(args.domain, args.instance)

	# time.clock() was removed in Python 3.8; time.perf_counter() is the
	# documented replacement for measuring elapsed time.
	start = time.perf_counter()
	V, policy, iterations = solve_model(mdp, args.gamma, args.epsilon)
	uptime = time.perf_counter() - start

	print_solution(V, policy, iterations)
	print(">> Value iteration converged in {time:.3f}sec after {it} iterations.".format(time=uptime, it=iterations))

	# Avoid a ZeroDivisionError should the solver ever report zero iterations.
	average = uptime / iterations if iterations else 0.0
	print(">> Average time per iteration = {0:.3f}sec.".format(average))
