# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/07_Agents/12n_agents.dqn.dueling.ipynb.

# %% auto 0
__all__ = ['DuelingHead']

# %% ../../../nbs/07_Agents/12n_agents.dqn.dueling.ipynb 3
# Python native modules
import os
from collections import deque
from typing import *
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.dataloader_experimental import DataLoader2
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torchdata.dataloader2.graph import find_dps,traverse,DataPipe,replace_dp,remove_dp
# Local modules
import torch
from torch.nn import *
import torch.nn.functional as F
from torch.optim import *

from ...torch_core import *

from ...core import *
from ..core import *
from ...pipes.core import *
from ...data.block import *
from ...memory.experience_replay import *
from ..core import *
from ..discrete import *
from ...loggers.core import *
from ...loggers.vscode_visualizers import *
from ...learner.core import *
from .basic import *
from .target import *

# %% ../../../nbs/07_Agents/12n_agents.dqn.dueling.ipynb 6
class DuelingHead(Module):
    "Dueling head: combines a 1-output value layer and an `n_actions`-output advantage layer."
    def __init__(self,
            hidden:int, # Input into the DuelingHead, likely a hidden layer input
            n_actions:int, # Number/dim of actions to output
            lin_cls=Linear # Layer class, called as `lin_cls(in_features,out_features)`
        ):
        super().__init__()
        self.val=lin_cls(hidden,1)
        self.adv=lin_cls(hidden,n_actions)

    def forward(self,xi):
        "Return `val + (adv - mean(adv))`, with the mean taken over the last (action) dim."
        val,adv=self.val(xi),self.adv(xi)
        # Center each sample's advantages over its own actions. A global
        # `adv.mean()` would also average over the batch dim, making one
        # sample's output depend on the other samples in the batch.
        adv=adv-adv.mean(dim=-1,keepdim=True)
        # `expand_as` repeats the single value output across the action dim.
        return val.expand_as(adv)+adv
