# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/12m_agents.dqn.double.ipynb.

# %% auto 0
__all__ = ['DoubleQCalc']

# %% ../../../nbs/07_Agents/12m_agents.dqn.double.ipynb 3
# Python native modules
import os
from collections import deque
from typing import *
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.dataloader_experimental import DataLoader2
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torchdata.dataloader2.graph import find_dps,traverse,DataPipe,replace_dp,remove_dp
# Local modules
import torch
from torch.nn import *
import torch.nn.functional as F
from torch.optim import *

from ...torch_core import *

from ...core import *
from ..core import *
from ...pipes.core import *
from ...data.block import *
from ...memory.experience_replay import *
from ..core import *
from ..discrete import *
from ...loggers.core import *
from ...loggers.jupyter_visualizers import *
from ...learner.core import *
from .basic import *
from .target import *

# %% ../../../nbs/07_Agents/12m_agents.dqn.double.ipynb 6
class DoubleQCalc(dp.iter.IterDataPipe):
    """Datapipe that stores double-Q style next-state values on the pipeline's learner.

    For every batch pulled from `source_datapipe`, the next action is chosen by
    an argmax over `learner.model(batch.next_state)` and its value is gathered
    from `learner.target_model(batch.next_state)`. The result is written to
    `learner.next_q`, and the entries selected by the flattened
    `batch.terminated` are then set to 0. The batch itself is yielded as-is.
    """
    def __init__(self,source_datapipe=None):
        "Keep `source_datapipe` and, when one is given, look up the `LearnerBase` in this pipe's graph."
        self.source_datapipe = source_datapipe
        # Without a source the lookup is skipped, so `learner` is left unset.
        if source_datapipe is None: return
        self.learner = find_dp(traverse(self),LearnerBase)

    def __iter__(self):
        "Yield each source batch after refreshing `learner.done_mask` and `learner.next_q`."
        for batch in self.source_datapipe:
            flat_terminated = batch.terminated.reshape(-1,)
            learner = self.learner
            learner.done_mask = flat_terminated
            with torch.no_grad():
                # Action selection uses `model`; the value lookup uses `target_model`.
                online_q = learner.model(batch.next_state)
                greedy_actions = online_q.argmax(dim=1).reshape(-1,1)
                target_q = learner.target_model(batch.next_state)
                learner.next_q = target_q.gather(1,greedy_actions)
            # In-place overwrite of the rows picked out by the mask.
            learner.next_q[learner.done_mask] = 0
            yield batch

    @classmethod
    def replace_dp(cls,old_dp=(QCalc,TargetModelQCalc)) -> Callable[[DataPipe],DataPipe]:
        """Build a function that swaps pipes of the `old_dp` types for a new `cls` pipe.

        The returned callable takes a datapipe, tries each type in `old_dp` in
        order, and returns the resulting datapipe. A type whose lookup or
        replacement raises `LookupError` is skipped; if no type was replaced a
        warning is issued and the datapipe is returned as it was.
        """
        def _replace_dp(pipe):
            swapped = False
            for stale_type in old_dp:
                try:
                    stale = find_dp(traverse(pipe),stale_type)
                    graph = traverse(pipe)
                    replacement = cls(stale.source_datapipe)
                    # `replace_dp` here is the module-level graph helper, not this classmethod.
                    new_graph = replace_dp(graph,stale,replacement)
                    # First value of the returned graph is a tuple; its first item becomes `pipe`.
                    pipe = list(new_graph.values())[0][0]
                except LookupError: continue
                swapped = True
            if not swapped:
                warn(f'Unable to find: {old_dp} in {cls} given {traverse(pipe)}')
            return pipe
        return _replace_dp
