# AUTOGENERATED! DO NOT EDIT! File to edit: 09_CurriculumSchedulerOneMachine.ipynb (unless otherwise specified).

__all__ = ['CurriculumSchedulerOneMachine']

# Cell
import os
import sys
import signal
import subprocess
import time
import numpy as np
import yaml
import random

# Cell
class CurriculumSchedulerOneMachine():
	"""Curriculum scheduler that runs one task environment at a time as a subprocess.

	The curriculum is a YAML mapping whose top-level keys name the tasks; tasks
	are visited in the sorted order of those keys. Every task entry is read for
	the keys ``environment``, ``window_size``, ``threshold``, ``pairs`` and
	``action_scale``. The environment of the active task is started through the
	shell command found in ``command_dic`` and is terminated (SIGTERM to its
	whole process group) before the next task's environment is launched.
	"""

	def __init__(self, yaml_path, command_dic, waiting_second=40):
		"""Load the curriculum and launch the environment of the first task.

		Args:
			yaml_path: Path of the YAML curriculum file.
			command_dic: Mapping from a task's ``environment`` value to the
				shell command that starts that environment.
			waiting_second: Number of one-second pauses taken after stopping
				and after starting an environment process.
		"""
		self.command_dic = command_dic
		self.waiting_second = waiting_second
		# CLoader only exists when PyYAML was built with LibYAML; fall back to
		# the pure-Python loader otherwise.
		# NOTE(security): neither of these is a "safe" loader - they can build
		# arbitrary Python objects, so only load curriculum files you trust.
		loader = getattr(yaml, 'CLoader', yaml.Loader)
		with open(yaml_path, 'r') as f:
			self.yaml_data = yaml.load(f, Loader=loader)
		self.yaml_sort = sorted(self.yaml_data)
		self.success_queue = []
		self.current_task_index = 0
		self.process = None
		self.__switch_task()

	def update(self, success):
		"""Record an episode outcome and return the setup for the next episode.

		The outcome is appended to a sliding window holding at most
		``window_size`` entries. Once the window is full, its mean exceeds the
		task's ``threshold`` and a later task exists, the scheduler moves to the
		next task, clears the window and restarts the environment process.

		Args:
			success: Outcome of the finished episode. It is averaged with
				``np.mean``, so it must be numeric or boolean.

		Returns:
			Tuple ``(start_goal_pair, action_scale)``: a randomly chosen entry
			of ``pairs`` and the ``action_scale`` of the task that is active
			after this call.
		"""
		task_name = self.yaml_sort[self.current_task_index]
		task = self.yaml_data[task_name]
		window_size = task["window_size"]
		threshold = task["threshold"]

		self.success_queue.append(success)
		# Trim before averaging so the rate never covers more than the last
		# window_size outcomes.
		if len(self.success_queue) > window_size:
			self.success_queue = self.success_queue[-window_size:]
		rate = np.mean(self.success_queue)
		print("------ CurriculumSchedulerOneMachine ------")
		print("Current Task:",task_name,task["environment"])
		print("window_size:",window_size, ", Current Buffer size:",len(self.success_queue))
		print("Threshold:",threshold, ", Current Avg success rate:",rate)

		window_full = len(self.success_queue) >= window_size
		has_next_task = self.current_task_index < len(self.yaml_sort)-1
		if window_full and has_next_task and rate > threshold:
			self.current_task_index += 1
			self.success_queue = []
			self.__switch_task()
			# The next episode runs in the new environment, so its setup must
			# come from the new task rather than the one just completed.
			task = self.yaml_data[self.yaml_sort[self.current_task_index]]

		pairs = task['pairs']
		# Index by position so any number of pairs works.
		start_goal_pair = pairs[random.randrange(len(pairs))]
		return start_goal_pair, task['action_scale']

	def __switch_task(self):
		"""Stop the running environment, if any, and start the current task's one."""
		if self.process is not None:
			try:
				# The child was started in its own session, so signalling its
				# process group also reaches everything the shell spawned.
				os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
			except ProcessLookupError:
				# The environment process is already gone; nothing to stop.
				pass
			self.__countdown()

		current_task_env = self.yaml_data[self.yaml_sort[self.current_task_index]]["environment"]
		command_str = self.command_dic[current_task_env]
		# NOTE(review): stdout is piped but never read in this class; a child
		# that prints a lot can block once the OS pipe buffer fills up.
		self.process = subprocess.Popen(command_str, stdout=subprocess.PIPE,shell=True, preexec_fn=os.setsid)
		print('\033[92m@@@@@@@@@ \\\\ Level Up /// @@@@@@@@@\033[0m')
		self.__countdown()

	def __countdown(self):
		"""Sleep ``waiting_second`` seconds, printing the remaining time each second."""
		for elapsed in range(self.waiting_second):
			time.sleep(1)
			print("waiting", self.waiting_second-elapsed)