# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/12l_agents.dqn.asynchronous.ipynb.

# %% auto 0
__all__ = ['ModelSubscriber', 'ModelPublisher']

# %% ../../../nbs/07_Agents/12l_agents.dqn.asynchronous.ipynb 3
# Python native modules
import os
from typing import *
from collections import deque
from copy import deepcopy
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
import torch.multiprocessing as mp
import torch
from torch.nn import *
import torch.nn.functional as F
from torch.optim import *
from torchdata.dataloader2.graph import find_dps,traverse,replace_dp,DataPipe
# Local modules
from ...torch_core import *
from ...core import *
from ..core import *
from ...pipes.core import *
from ...data.block import *
from ...memory.experience_replay import *
from ..core import *
from ..discrete import *
from ...loggers.core import *
from ...loggers.jupyter_visualizers import *
from ...learner.core import *
from .basic import *
from ...data.dataloader2 import *
from torchdata.dataloader2 import DataLoader2,DataLoader2Iterator

# %% ../../../nbs/07_Agents/12l_agents.dqn.asynchronous.ipynb 7
class ModelSubscriber(dp.iter.IterDataPipe):
    """Load model weights published to this pipe into the upstream agent's model.

    If an agent is passed to another process and the 'spawn' start method is
    used, then this pipe is needed.

    Items from `source_datapipe` that are `GetInputItemRequest`s whose `key`
    starts with `'model_state_dict_publish_'` are consumed. Their `value` is
    passed to `self.model.load_state_dict` and they are NOT yielded. Every
    other item is yielded unchanged.
    """
    def __init__(self,source_datapipe):
        super().__init__()
        self.source_datapipe = source_datapipe
        # Model of the `AgentBase` found by `find_dp` in the graph upstream of this pipe;
        # published state dicts are loaded into it.
        self.model = find_dp(traverse(self.source_datapipe,only_datapipe=True),AgentBase).model

    def __iter__(self):
        for x in self.source_datapipe:
            if isinstance(x,GetInputItemRequest) and x.key.startswith('model_state_dict_publish_'):
                # A weight update: apply it and swallow the message.
                self.model.load_state_dict(x.value)
                continue
            yield x

    @classmethod
    def insert_dp(cls,old_dp=InputInjester) -> Callable[[DataPipe],DataPipe]:
        """Return a pipe transform that wraps the `old_dp` found in a pipe with `cls`.

        The returned function traverses `pipe`, locates the `old_dp` instance,
        replaces it with `cls(<that instance>)` via `replace_dp`, and returns the
        first datapipe of the resulting graph.
        """
        def _insert_dp(pipe):
            graph = traverse(pipe,only_datapipe=True)
            target = find_dp(graph,old_dp)
            new_graph = replace_dp(graph,target,cls(target))
            # First graph entry; index [0] of the entry is the datapipe itself.
            return list(new_graph.values())[0][0]
        return _insert_dp

# %% ../../../nbs/07_Agents/12l_agents.dqn.asynchronous.ipynb 8
class ModelPublisher(dp.iter.IterDataPipe):
    """Send the learner's model `state_dict` to every worker protocol client.

    Items from `source_datapipe` are yielded unchanged, with one exception.
    A `GetInputItemResponse` whose `value` starts with
    `'model_state_dict_publish_'` is an acknowledgement from client `N`
    (parsed from the suffix). It marks that client as ready again and is
    NOT yielded.

    On every item whose index is a multiple of `publish_freq`, the current
    `state_dict` is sent with `client.request_input_item(key=..., value=state)`.
    Only clients that are not still awaiting an acknowledgement are sent to.
    The index counts every source item, including consumed acknowledgements.
    """
    def __init__(self,
            source_datapipe,
            publish_freq:int=1,
            # Sometimes it's not possible to share the current model due to cuda issues.
            # `do_deepcopy` will copy the model and move the copy to cpu in order to publish it.
            do_deepcopy:bool=False
        ):
        super().__init__()
        if publish_freq < 1:
            # `publish_freq` is used as a modulus in `__iter__`; 0 would divide by zero there.
            raise ValueError(f'publish_freq must be >= 1, got {publish_freq}')
        self.source_datapipe = source_datapipe
        # Model of the `LearnerBase` found by `find_dp` in the graph upstream of this pipe.
        self.model = find_dp(traverse(self,only_datapipe=True),LearnerBase).model
        self.publish_freq = publish_freq
        # Filled lazily by `_reset` on the first item from the source.
        self.protocol_clients = []
        # `_expect_response[i]` is True while client `i` has not yet acknowledged the last publish.
        self._expect_response = []
        self.initialized = False
        self.do_deepcopy = do_deepcopy

    @classmethod
    def insert_dp(cls,old_dp=LoggerBasePassThrough,publish_freq=1,do_deepcopy=False) -> Callable[[DataPipe],DataPipe]:
        """Return a pipe transform that wraps the `old_dp` found in a pipe with `cls`.

        `publish_freq` and `do_deepcopy` are forwarded to the new `ModelPublisher`.
        """
        def _insert_dp(pipe):
            graph = traverse(pipe,only_datapipe=True)
            target = find_dp(graph,old_dp)
            new_graph = replace_dp(
                graph,
                target,
                cls(target,publish_freq=publish_freq,do_deepcopy=do_deepcopy)
            )
            # First graph entry; index [0] of the entry is the datapipe itself.
            return list(new_graph.values())[0][0]
        return _insert_dp

    def _reset(self):
        """(Re)build the client list from the learner's dataloaders and mark all clients ready.

        Reads `.protocol` from every entry of `dl.datapipe.iterable.datapipes`
        for each `dl` in the upstream learner's `iterable`.
        NOTE(review): that attribute chain is assumed by the original code; confirm it
        against the dataloader setup if the pipeline structure changes.
        """
        learner = find_dp(traverse(self,only_datapipe=True),LearnerBase)
        self.protocol_clients = [
            q_wrapper.protocol
            for dl in learner.iterable
            for q_wrapper in dl.datapipe.iterable.datapipes
        ]
        self._expect_response = [False]*len(self.protocol_clients)
        self.initialized = True

    def __iter__(self):
        for i,batch in enumerate(self.source_datapipe):
            # Client discovery is deferred until the first item arrives. Presumably the
            # worker protocols only exist once the learner's dataloaders are running.
            if not self.initialized: self._reset()
            if isinstance(batch,GetInputItemResponse) and batch.value.startswith('model_state_dict_publish_'):
                # Acknowledgement from a worker: it may receive the next publish.
                client_num = int(batch.value.replace('model_state_dict_publish_',''))
                self._expect_response[client_num] = False
                continue
            # (this item should publish) and (there are clients) and (at least one is ready)
            if i%self.publish_freq==0 and self.protocol_clients and not all(self._expect_response):
                with torch.no_grad():
                    # We need to deepcopy the model itself because `Module.cpu()` is an in-place op.
                    # We can't keep the model on cuda because mp.Manager passes the tensors
                    # around too much and causes errors, ref: https://github.com/pytorch/pytorch/issues/30401
                    # This is also why we can't just call state_dict directly: it returns
                    # references to cuda tensors.
                    if self.do_deepcopy:
                        state = deepcopy(self.model).cpu().state_dict()
                    else:
                        state = self.model.state_dict()

                    for client_id,client in enumerate(self.protocol_clients):
                        if not self._expect_response[client_id]:
                            client.request_input_item(
                                key=f'model_state_dict_publish_{client_id}',value=state
                            )
                            self._expect_response[client_id] = True
            yield batch
