# AUTOGENERATED! DO NOT EDIT! File to edit: 05_DoorGym_gazebo_inference.ipynb (unless otherwise specified).

__all__ = ['download_model', 'init_model', 'inference', 'get_value', 'plot_curve']

# Cell
import torch
import os
import gdown
import yaml
from zipfile import ZipFile
import csv
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math

# Cell
def download_model(id, path, name):
    """Download a model.

    Download ``<name>.pt`` from Google Drive into ``<path>/model/``. The
    download is skipped when a file with that name already exists there.

    Args:
        id : The Google Drive file id. (Shadows the builtin ``id``; the
            name is kept so existing keyword callers keep working.)
        path : The directory under which the ``model`` folder is created.
        name : The output file name, without the ``.pt`` extension.

    Return:
        model_path : The model path, ``<path>/model/<name>.pt``.

    """
    dataset_url = 'https://drive.google.com/u/1/uc?id=' + id
    dataset_name = name + ".pt"

    # Create the folder the file is actually written to. Creating a bare
    # 'model' folder only worked when ``path`` was the current directory.
    model_dir = os.path.join(path, "model")
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, dataset_name)

    if not os.path.isfile(model_path):
        gdown.download(dataset_url, output=model_path, quiet=False)

    return model_path

# Cell
def init_model(model_path, state_dim, map_location=None):
    """Initial model.

    Load a pickled ``(actor_critic, ob_rms)`` checkpoint, switch the policy
    to evaluation mode and store ``state_dim`` on it as ``model.nn``.

    Args:
        model_path : The trained model path.
        state_dim : Dimension of state; assigned to the model's ``nn``
            attribute.
        map_location : Optional ``torch.load`` ``map_location`` (for
            example ``"cpu"``). It lets a checkpoint saved on a GPU be
            loaded on a CPU-only machine. ``None`` (default) keeps torch's
            default placement.

    Return:
        model : After initialization model (the loaded ``actor_critic``).

    """
    import inspect

    load_kwargs = {}
    if map_location is not None:
        load_kwargs["map_location"] = map_location
    # The checkpoint is a fully pickled module (not a state_dict), so it
    # can only be restored with full unpickling. torch >= 2.6 defaults to
    # weights_only=True, which rejects such checkpoints. Request the legacy
    # behaviour explicitly where the keyword exists (torch >= 1.13).
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoints that come from a trusted source.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    # The observation running-mean statistics are not used here.
    actor_critic, _ob_rms = torch.load(model_path, **load_kwargs)
    actor_critic = actor_critic.eval()
    actor_critic.nn = state_dim

    return actor_critic

# Cell
def inference(model, state, hidden_state, masks=None, deterministic=True):
    """Inference RL_mm in gazebo or real world.

    Run one policy step through ``model.act`` with autograd disabled and
    return the chosen action and the next recurrent hidden state.

    Args:
        model : The model after initialization.
        state : The environment and robot observation.
        hidden_state : The recurrent network setting.
        masks : Optional mask tensor passed to ``model.act``. Defaults to
            ``torch.zeros(1, 1)``, as before. NOTE(review): in
            a2c-ppo-acktr-style recurrent policies a zero mask resets the
            hidden state on every call. Pass ``torch.ones(1, 1)`` after the
            first step to carry it over; confirm for this model.
        deterministic : Forwarded to ``model.act``; defaults to ``True``.

    Return:
        action : The robot action.
        recurrent_hidden_states : The next hidden_state.

    """
    if masks is None:
        masks = torch.zeros(1, 1)
    with torch.no_grad():
        _, action, _, recurrent_hidden_states = model.act(
            state, hidden_state, masks, deterministic=deterministic)

    return action, recurrent_hidden_states

# Cell
def get_value(file_name):
    """Read the ``Value`` column of a CSV file as a float array.

    Every data row of the CSV must have a ``Value`` column; each entry is
    converted with ``float`` in file order.

    Args:
        file_name : The csv file
    Return:
        value : Output reward or metric value as a numpy array

    """
    with open(file_name, newline='') as handle:
        values = [float(record["Value"]) for record in csv.DictReader(handle)]

    return np.array(values)

def _show_series(values, y_label, x_label):
    """Plot ``values`` against their indices with a grid and show the figure."""
    positions = np.array(range(len(values)))
    plt.grid(True)
    plt.plot(positions, values)
    plt.ylabel(y_label)
    plt.xlabel(x_label)
    plt.show()


def plot_curve(reward, metric):
    """Show two curves: rewards per episode, then success rate per step.

    Each series is drawn in its own figure against its element index.

    Args:
        reward : All reward value
        metric : All metric value
    Return:
        None

    """
    _show_series(reward, "Rewards", "Episode")
    _show_series(metric, "Door open success rate", "Steps")
