# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10c_agents.dqn.double.ipynb.

# %% auto 0
__all__ = ['DoubleDQNTrainer']

# %% ../nbs/10c_agents.dqn.double.ipynb 3
# Python native modules
import os
from collections import deque
# Third party libs
import torch
from torch.nn import *
from fastcore.all import *
from fastai.learner import *
from fastai.torch_basics import *
from fastai.torch_core import *
from fastai.callback.all import *
# Local modules
from ...fastai.data.block import *
from ...agent import *
from ...core import *
from .core import *
from .targets import *
from ...memory.experience_replay import *

# %% ../nbs/10c_agents.dqn.double.ipynb 5
class DoubleDQNTrainer(DQNTargetTrainer):
    """Target trainer that builds Double DQN regression targets after the forward pass.

    Action *selection* for the next state uses the online network
    (`self.model.model`), while action *evaluation* uses `self.target_model`.
    `target_model`, `discount` and `n_steps` are not defined in this class;
    they are presumably supplied by `DQNTargetTrainer` / the learner — confirm
    against those definitions.
    """
    def after_pred(self):
        """Replace the learner's `yb` with a Q-value target tensor for the loss.

        Reads the batch dict stored as the first element of `yb` (keys used
        here: 'state', 'action', 'reward', 'next_state', 'done') and writes
        `done_mask`, `next_q`, `targets`, `pred` and `yb` on the learner, plus
        `_yb` on this callback.
        """
        # `yb` arrives as a tuple whose first element is the batch dict.
        batch=self.yb[0]
        self.learn.yb=batch
        # Independent copy of the raw batch tensors. It is not read again in
        # this method; presumably consumed by the base class or another
        # callback — TODO confirm.
        self._yb=({k:v.clone() for k,v in batch.items()},)
        # Flattened so it can index rows of the (-1,1)-shaped `next_q` below;
        # assumes 'done' is a boolean tensor — TODO confirm.
        done_mask=batch['done'].reshape(-1,)
        self.learn.done_mask=done_mask
        # The bootstrap target is a fixed regression target, so build it
        # without autograd: this avoids constructing a graph through the
        # target network (and the online next-state pass) and avoids
        # back-propagating the loss into the target network's parameters.
        with torch.no_grad():
            # Double DQN: the online network picks the next action ...
            chosen_actions=self.model.model(batch['next_state']).argmax(dim=1).reshape(-1,1)
            # ... and the target network scores that chosen action.
            next_q=self.target_model(batch['next_state']).gather(1,chosen_actions)
            # Rows flagged as done contribute no bootstrapped value.
            next_q[done_mask]=0
            self.learn.next_q=next_q
            self.learn.targets=batch['reward']+next_q*(self.discount**self.n_steps)
        # Online prediction for the current state; this is the only part the
        # loss needs gradients for.
        pred=self.learn.model.model(batch['state'])
        self.learn.pred=pred
        # Start from the prediction and overwrite only the taken action's
        # column with the target, so `t_q` differs from `pred` solely at the
        # taken actions.
        t_q=pred.clone()
        t_q.scatter_(1,batch['action'],self.learn.targets)
        self.learn.yb=(t_q,)
