# AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/99_test_base.ipynb (unless otherwise specified).

__all__ = ['TestBase']

# Cell

import shutil
import tempfile

from .predefined_problems import (
    get_weibo_cws_fn, get_weibo_fake_cls_fn, get_weibo_fake_multi_cls_fn,
    get_weibo_fake_ner_fn, get_weibo_masklm)
from .params import BaseParams

class TestBase():
    """Fixture holder that prepares scratch directories and a ``BaseParams``.

    Instantiating the class immediately runs :meth:`setUp`, which creates two
    temporary directories and calls :meth:`prepare_params`. Call
    :meth:`tearDown` when finished to delete the directories.

    Attributes set by :meth:`setUp` / :meth:`prepare_params`:
        tmpfiledir: temporary directory assigned to ``params.tmp_file_dir``.
        tmpckptdir: temporary directory assigned to ``params.ckpt_dir``.
        problem_type_dict: problem name -> problem type string.
        processing_fn_dict: problem name -> object returned by the matching
            ``get_weibo_*`` factory.
        params: the configured ``BaseParams`` instance.
    """

    # Glob passed as ``file_path`` to every ``get_weibo_*`` factory. Kept as a
    # class attribute so a subclass can point at another copy of the data
    # without overriding ``prepare_params``.
    data_file_pattern = '/data/bert-multitask-learning/data/ner/weiboNER*'

    def __init__(self):
        self.setUp()

    def setUp(self) -> None:
        """Create the temporary directories and build ``self.params``.

        If anything fails after the first directory exists, whatever was
        created is removed before the exception propagates; otherwise the
        directories would leak, because a failed ``__init__`` leaves the
        caller with no instance to call ``tearDown`` on.
        """
        self.tmpfiledir = tempfile.mkdtemp()
        try:
            self.tmpckptdir = tempfile.mkdtemp()
            self.prepare_params()
        except BaseException:
            self.tearDown()
            raise

    def tearDown(self) -> None:
        """Delete both temporary directories.

        Safe to call more than once and after a partially failed ``setUp``:
        a directory that is unset or already gone is skipped. The second
        directory is still attempted when removing the first one raises.
        """
        try:
            self._remove_dir(getattr(self, 'tmpfiledir', None))
        finally:
            self._remove_dir(getattr(self, 'tmpckptdir', None))

    @staticmethod
    def _remove_dir(path) -> None:
        """Recursively delete ``path``; ignore ``None`` and missing paths."""
        if path is None:
            return
        try:
            shutil.rmtree(path)
        except FileNotFoundError:
            # Already removed (e.g. by an earlier tearDown) - nothing to do.
            pass

    def prepare_params(self):
        """Populate the problem dictionaries and configure ``self.params``.

        Reads ``self.tmpfiledir`` and ``self.tmpckptdir``, so the temporary
        directories must exist on the instance before this is called.
        """
        self.problem_type_dict = {
            'weibo_fake_ner': 'seq_tag',
            'weibo_cws': 'seq_tag',
            'weibo_fake_multi_cls': 'multi_cls',
            'weibo_fake_cls': 'cls',
            'weibo_masklm': 'masklm'
        }

        # Every factory is given the same data glob.
        file_path = self.data_file_pattern
        self.processing_fn_dict = {
            'weibo_fake_ner': get_weibo_fake_ner_fn(file_path=file_path),
            'weibo_cws': get_weibo_cws_fn(file_path=file_path),
            'weibo_fake_cls': get_weibo_fake_cls_fn(file_path=file_path),
            'weibo_fake_multi_cls': get_weibo_fake_multi_cls_fn(file_path=file_path),
            'weibo_masklm': get_weibo_masklm(file_path=file_path)
        }

        self.params = BaseParams()
        self.params.tmp_file_dir = self.tmpfiledir
        self.params.ckpt_dir = self.tmpckptdir

        # Model, config and tokenizer all use the same pretrained name.
        pretrained_name = 'voidful/albert_chinese_tiny'
        self.params.transformer_model_name = pretrained_name
        self.params.transformer_config_name = pretrained_name
        self.params.transformer_tokenizer_name = pretrained_name
        self.params.transformer_tokenizer_loading = 'BertTokenizer'
        self.params.transformer_config_loading = 'AlbertConfig'

        self.params.add_multiple_problems(
            problem_type_dict=self.problem_type_dict, processing_fn_dict=self.processing_fn_dict)
