from yqn_config.base_config import BaseConfig
from yqn_pytorch_framework.train.base_dataset import BaseDataset
from ${module_name}.data.${module_name}_data_handler import load_files
import numpy as np


class ${class_name}Dataset(BaseDataset):

    def __init__(self, config: BaseConfig, data_flag):
        super(${class_name}Dataset, self).__init__(config)
        if data_flag == 'train':
            (self.input_features,
             self.label_list) = load_files(self.config,
                                           self.config.train_file_dir)
        else:
            (self.input_features,
             self.label_list) = load_files(self.config,
                                           self.config.val_file_dir)

    def get_item_size(self):
        return len(self.label_list)

    def get_item(self, index):
        # TODO fill ${module_name} Dataset get_item
        input_numpy = self.input_features[index].astype(dtype=np.float32)
        labels = np.array(self.label_list[index]).astype(dtype=np.float32)
        return (input_numpy, labels)
