# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05_data.block.ipynb (unless otherwise specified).

# Public names of this module, each listed once. `DQN` and `TestDataset` are defined twice
# further down in this file; the later definition is the one that ends up exported.
__all__ = ['DQN', 'TestDatasetNoModule', 'TestDataset', 'init_experience', 'FakeAgent', 'ExperienceSource',
           'SourceDataset']

# Cell
# Python native modules
import os
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from torch.utils.data import Dataset
from torch import nn


import numpy as np
import gym
import time,sys
import torch.multiprocessing as mp

# Local modules

# Cell
class DQN(Module):
    "Tiny fully connected net (4 inputs -> 50 hidden -> 2 outputs, ReLU after each linear layer)."
    def __init__(self):
        layers=[
            nn.Linear(4,50),
            nn.ReLU(),
            nn.Linear(50,2),
            nn.ReLU(),
        ]
        self.policy=nn.Sequential(*layers)

    def forward(self,x):
        "Run `x` through `self.policy` and return the argmax of the result along dim 0."
        scores=self.policy(x)
        return scores.argmax(dim=0)

# Cell
class TestDatasetNoModule(IterableDataset):
    "Debugging dataset with no model: each lookup logs the serving pid and returns a random `(1,4)` tensor."
    def __init__(self,device='cpu'):
        self.device,self.pids=device,[os.getpid()]

    def __len__(self):
        return 10

    def __getitem__(self,idx):
        "Log progress, remember the current pid in `self.pids` and return random data on `self.device`."
        print('starting',flush=True)
        try:
            print('making env')
            pid=os.getpid()
            print('pid is: ',pid,flush=True)
            self.pids.append(pid)
            sample=torch.rand(1,4).to(device=self.device)
        except Exception as err:
            # Report the failure and still hand back a sample of the same shape.
            print(err,'it crashed lol omg',flush=True)
            sample=torch.rand(1,4).to(device=self.device)
        return sample

# Cell

class TestDataset(Dataset):
    """Map-style debugging dataset that steps one shared CartPole env with `policy`.

    Every `__getitem__` call records the current process id and the `id()` of the env
    object on `torch.multiprocessing` queues (`pids`, `envs`), then returns one state
    tensor on `device`. The requested index is ignored.
    """
    def __init__(self,policy,device='cpu'):
        self.policy=policy
        self.device=device
        self.policy.to(device=self.device)

        self.pids=mp.Queue()
        self.pids.put(os.getpid())
        self.envs=mp.Queue()
        # NOTE(review): the first `envs` entry is a pid, whereas `__getitem__` adds `id(self.env)`;
        # kept as-is because the env does not exist yet at this point -- confirm intent.
        self.envs.put(os.getpid())

        self.env=gym.make('CartPole-v1')

    def __len__(self): return 100
    def __getitem__(self,idx):
        "Reset the env, take one step chosen by `policy` and return the resulting state tensor."
        self.pids.put(os.getpid())
        self.envs.put(id(self.env))
        try:
            next_state=self.env.reset()
            print(id(self.env),' ')
            print('pid is: ',os.getpid(),flush=True)
            action=self.policy(Tensor(next_state).to(device=self.device)).cpu().numpy()
            next_state, r, is_done, _=self.env.step(action)
            # An episode that ended on this step is replaced by a fresh initial state.
            if is_done:next_state=self.env.reset()
            return Tensor(next_state).to(device=self.device)
        except Exception as e:
            # Best-effort fallback: report the error and return a random state. It has the same
            # `(4,)` shape as a real state so that it can still be batched with the other items
            # (the previous `(0,4)` fallback was an empty tensor).
            print(e)
            return Tensor(np.random.rand(4)).to(device=self.device)

# Cell
# Python native modules
import os
from collections import deque
from time import sleep
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from torch.utils.data import Dataset
from torch import nn
import torch

# Local modules
from ..core import *

# Cell
class DQN(Module):
    "Tiny fully connected net (4 inputs -> 50 hidden -> 2 outputs, ReLU after each linear layer)."
    def __init__(self):
        layers=[
            nn.Linear(4,50),
            nn.ReLU(),
            nn.Linear(50,2),
            nn.ReLU(),
        ]
        self.policy=nn.Sequential(*layers)

    def forward(self,x):
        "Run `x` through `self.policy` and return the argmax of the result along dim 0."
        scores=self.policy(x)
        return scores.argmax(dim=0)

# Cell
def init_experience(but='',**kwargs):
    "Build the default experience, drop the comma-separated keys named in `but`, then merge in `kwargs`."
    defaults=D(
        state=0,
        action=0,
        next_state=0,
        reward=[0],
        done=False,
        step=0,
        steps=0,
        n_env=0,
        image=0,
    )
    # Keys are removed before `kwargs` is merged, so a key named in `but` can still be supplied via `kwargs`.
    for name in but.split(','):
        if name not in defaults: continue
        del defaults[name]
    return D(merge(defaults,kwargs))

# Cell
def _state2experience(s,**kwargs):
    "Wrap state `s` in a default experience; `kwargs` are forwarded to `init_experience`."
    return init_experience(state=s,**kwargs)

def _env_reset(o):
    "Call `o.reset()` and return its result."
    return o.reset()

def _env_seed(o,seed):
    "Call `o.seed(seed)` and return its result."
    return o.seed(seed)

def _env_render(o,mode='rgb_array'):
    "Render `o` with `mode` and return a copy of the frame wrapped in a one-item list."
    frame=o.render(mode=mode)
    return [frame.copy()]

def _env_step(o,*args,**kwargs):
    "Forward all arguments to `o.step` and return its result."
    return o.step(*args,**kwargs)

class FakeAgent:
    "Stand-in agent that samples one action from `action_space` per row of `state`."
    def __init__(self,action_space):
        store_attr()

    def __call__(self,state,**kwargs):
        "Return `(actions, D(kwargs))` with one sampled action for each of the `state.shape[0]` rows."
        n_rows=state.shape[0]
        actions=L([self.action_space.sample() for _ in range(n_rows)])
        return actions,D(kwargs)

class ExperienceSource(Stateful):
    # Steps a pool of `n_envs` gym environments with `agent` and yields the resulting experiences.
    # The class/method docstrings are attached by the `add_docs` call right after the class body.
    # `pool` is listed in `_stateattrs`; presumably `Stateful` leaves those attributes out of the
    # pickled state so that envs are rebuilt per process -- TODO confirm against `Stateful`.
    _stateattrs=('pool',)
    def __init__(self,env:str,agent=None,n_envs:int=1,steps_count:int=1,steps_delta:int=1,
                 seed:int=None,render=None,num_workers=0,but='',**kwargs):
        # `store_attr()` is relied on to expose the arguments as attributes (`self.render`, `self.but`, ...).
        store_attr()
        # Extra keyword arguments are kept and later passed to `gym.make` in `_init_state`.
        self.env_kwargs=kwargs
        # Start with an empty pool; `__iter__` calls `_init_state` while the pool is empty.
        self.pool=L()
        # Without rendering, add `image` to the keys that `init_experience` removes.
        if self.render is None: self.but+=',image'

    def _init_state(self):
        "Inits the histories, experiences, and the environment pool when sent to a `Process`"
        # Build `n_envs` (deque, env) pairs and transpose them into a list of histories and a list of envs.
        # NOTE(review): `history` and `steps_delta` are not used anywhere else in the code visible here.
        self.history,self.pool=L((deque(maxlen=self.steps_count),
                                  gym.make(self.env,**self.env_kwargs))
                                  for _ in range(self.n_envs)).zip().map(L)
        # Every env receives the same `seed` value.
        self.pool.map(_env_seed,seed=self.seed)
        # Fall back to random actions drawn from the first env's action space.
        if self.agent is None: self.agent=FakeAgent(self.pool[0].action_space)
        self.reset_all()

    def reset_all(self):
        # Reset every env and turn each initial state into an experience without the `but` keys.
        self.experiences=self.pool.map(_env_reset)
        self.experiences=self.experiences.map(_state2experience,but=self.but)
        # Fold the per-env experiences into a single object with `+`
        # (presumably `D.__add__` concatenates field-wise -- TODO confirm).
        self.experiences=sum(self.experiences[1:],self.experiences[0])
        self.attempt_render(self.experiences)

    def attempt_render(self,experiences,indexes=None):
        # Does nothing unless a render mode was given.
        if self.render is not None:
            # Render either the whole pool or only the envs selected by `indexes`.
            pool=self.pool if indexes is None else self.pool[indexes]
            renders=pool.map(_env_render,mode=self.render)
            # No idea why we have to do this, but multiprocessing hangs forever otherwise
            if self.num_workers>0:sleep(0.1)
            # Stack the rendered frames and store them as floats under the `image` key.
            experiences['image']=np.vstack(renders).astype(float)

    def __iter__(self):
        "Iterates through a list of environments."
        # Lazily build the envs on first iteration (the pool is empty until `_init_state` ran).
        if not self.pool:self._init_state()
        # Endless generator: there is no exit condition inside this loop.
        while True:
#             try:
#             print(self.experiences)
            # Presumably the indexes of the experiences whose `done` flag is False -- TODO confirm `D.argwhere`.
            not_done_idxs=self.experiences.argwhere('done',L.argwhere,lambda x:x==False)
            not_done_experiences=self.experiences.filter(indexes=not_done_idxs)
            # The agent receives the selected experience fields as keyword arguments and
            # returns the actions plus a second value that is merged into the result below.
            actions,experiences=self.agent(**not_done_experiences)
            # Step each selected env with its action, then transpose the step results.
            step_res=self.pool[not_done_idxs].zipwith(actions).starmap(_env_step)
            next_states,rewards,dones,info=step_res.zip().map(L)

            # When rendering is enabled this updates `self.experiences['image']` for the selected envs.
            self.attempt_render(self.experiences,not_done_idxs)

            # NOTE(review): nothing in this loop assigns back to `self.experiences`, so the stored
            # state/done values do not appear to advance between iterations -- confirm.
            experiences=D(merge(not_done_experiences,experiences,
                                D(next_state=next_states,reward=rewards,done=dones)))
            if self.n_envs>1:
                # `out` is only used by the debugging `print` below.
                out={k:experiences[k][not_done_idxs[0]] for k in experiences}
                # Calls `experiences.subset` once per selected index using 2 thread workers.
                split_experiences=parallel(partial(experiences.subset),not_done_idxs,
                                           threadpool=True,n_workers=2,progress=False)
                print(split_experiences[0]['state'].shape,out['state'].shape)
                yield split_experiences
            else:
                yield experiences # {'actions':actions}
#             except ValueError:
#                 self.reset_all()

# Docstrings for the class and for `reset_all` / `attempt_render` are supplied here.
add_docs(ExperienceSource,
        """Iterates through `n_envs` of `env` feeding experience or states into `agent`.
           If `agent` is None, then random actions will be taken instead.
           It will return `steps_count` experiences every `steps_delta`.
           At the end of an env, it will return `steps_count-1` experiences per next. """,
        reset_all="resets the envs and experience",
        attempt_render="Updates `experiences` with images if `render is not None`. Optionally indexes can be passed.")

# Cell
class SourceDataset(IterableDataset):
    "Thin `IterableDataset` over a `source` object; `wif` re-inits the source (for use when `num_workers>0`)."
    def __init__(self,source=None):
        store_attr('source')

    def __iter__(self):
        "Hand iteration over to `self.source`."
        wrapped=self.source
        return iter(wrapped)

    def wif(self):
        "Call `_init_state` on `self.source`; returns nothing."
        self.source._init_state()

# Cell
class TestDataset(IterableDataset):
    """Iterable debugging dataset that yields the integers in `[start,end)`.

    On every `__iter__` call it also creates CartPole envs in `self.envs`: `n_envs` of them
    in single-process loading, or `ceil(n_envs/num_workers)` inside a `DataLoader` worker.
    In a worker, the integer range is split into contiguous per-worker slices so that the
    workers do not all yield the same items.
    """
    def __init__(self,start=1,end=10,policy=None,device='cpu',n_envs=1):
        store_attr('start,end,policy,device,n_envs')

    def init_envs(self,n):
        "Create `n` fresh CartPole envs and store them in `self.envs`."
        self.envs=[gym.make('CartPole-v1') for i in range(n)]

    def __iter__(self):
        "Return an iterator over this process's share of `range(start,end)`."
        worker_info=torch.utils.data.get_worker_info()

        if worker_info is None:  # single-process data loading, return the full iterator
            self.init_envs(self.n_envs)
            iter_start,iter_end=self.start,self.end
        else:  # in a worker process
            # split workload: a share of the envs and a contiguous slice of the range
            n_workers=worker_info.num_workers
            self.init_envs(int(math.ceil(self.n_envs/n_workers)))
            per_worker=int(math.ceil((self.end-self.start)/n_workers))
            iter_start=self.start+worker_info.id*per_worker
            # The last slices may be short or empty when the range does not divide evenly.
            iter_end=min(iter_start+per_worker,self.end)
        return iter(range(iter_start, iter_end))