from yqn_config.base_config import BaseConfig
from yqn_pytorch_framework.train.base_engine import BaseModelEngine
from yqn_pytorch_framework.train.base_train import BaseModelTrain

from ${module_name}.model.pytorch_${module_name}_model import ${class_name}Model
from ${module_name}.train.pytorch_${module_name}_dataset import ${class_name}Dataset


class ${class_name}ModelTrain(BaseModelTrain):
    def __init__(self, engine: BaseModelEngine):
        super(${class_name}ModelTrain, self).__init__(engine)

    def load_model(self, model_config: BaseConfig):
        model = ${class_name}Model(model_config)
        return model

    def load_data(self, config: BaseConfig, data_flag='train'):
        return ${class_name}Dataset(config, data_flag)
