import torch
import torch.nn.functional as F
from torch import nn
from yqn_config.base_config import BaseConfig
import numpy as np


class ${class_name}Loss(nn.Module):

    def __init__(self, config: BaseConfig):
        super(${class_name}Loss, self).__init__()
        self.config = config

    def forward(self, predicts, targets):
        predicts = torch.sigmoid(predicts)
        bce_loss = F.binary_cross_entropy_with_logits(predicts, targets, reduction='none')
        return bce_loss.mean()
