from yqn_config.base_config import BaseConfig
from yqn_pytorch_framework.train.base_engine import BaseModelEngine
from ${module_name}.train.pytorch_${module_name}_acc import ${class_name}Acc
from ${module_name}.train.pytorch_${module_name}_loss import ${class_name}Loss


class ${class_name}Engine(BaseModelEngine):
    def __init__(self, config: BaseConfig):
        super(${class_name}Engine, self).__init__(config)
        self.loss_function = ${class_name}Loss(config)
        self.acc_function = ${class_name}Acc(config)

    def criterion(self, outputs, labels):
        return [self.loss_function(outputs, labels)]

    def accuracy(self, outputs, labels):
        return [self.acc_function(outputs, labels)]

    def load_inputs(self, dataset):
        return [dataset[0]]

    def load_labels(self, dataset):
        return [dataset[1]]
