import numpy as np
import torchvision.transforms as transforms
import torch
from yqn_config.base_config import BaseConfig
from yqn_pytorch_framework.infer.base_infer import BaseModelInfer
import numpy as np

from ${module_name}.model.pytorch_${module_name}_model import ${class_name}Model
from ${module_name}.data.${module_name}_data_handler import *


class ${class_name}Infer(BaseModelInfer):
    def __init__(self, config: BaseConfig):
        super(${class_name}Infer, self).__init__(config)

    def get_infer_model_file(self):
        # TODO fill ${module_name} model path
        infer_model_path = self.config.get_infer_model_path('')
        return infer_model_path

    def load_model(self):
        model = ${class_name}Model(self.config)
        return model

    def get_input_size_without_batch(self):
        # TODO fill ${module_name} input size
        pass

    def pre_handle(self, input_features):
        # TODO fill ${module_name} pre_handle
        pass

    def format_output(self, outputs, input_features):
        # TODO fill ${module_name} format_output
        pass
