# Copyright 2020 The tfaip authors. All Rights Reserved.
#
# This file is part of tfaip.
#
# tfaip is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# tfaip is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# tfaip. If not, see http://www.gnu.org/licenses/.
# ==============================================================================
from dataclasses import dataclass, field
from typing import List, Dict, Optional

from dataclasses_json import dataclass_json, config, LetterCase

from tfaip.base.data.databaseparams import DataBaseParams
from tfaip.base.evaluator.params import EvaluatorParams
from tfaip.base.model.modelbaseparams import ModelBaseParams
from tfaip.util.argumentparser import dc_meta
from tfaip.util.versioning import get_commit_hash
from tfaip import __version__


@dataclass_json
@dataclass
class ScenarioBaseParams:
    """
    Define the global params of a scenario.

    Bundles the params of the model (model_params), of the data pipeline (data_params) and of the
    evaluator (evaluator_params), together with debugging and export options.

    NOTE: add @dataclass_json and @dataclass annotations to inherited class
    (the decorators are not inherited: fields declared in a subclass are only registered as
    dataclass fields, and thus included in JSON (de)serialization, if the subclass is decorated itself)
    """
    # Debugging options
    debug_graph_construction: bool = field(default=False, metadata=dc_meta(
        help="Build the graph in pure eager mode to debug the graph construction on real data"
    ))
    debug_graph_n_examples: int = field(default=1, metadata=dc_meta(
        help="number of examples to take from the validation set for debugging, -1 = all"
    ))

    # Evaluation output
    print_eval_limit: int = field(default=10, metadata=dc_meta(
        help="Number of evaluation examples to print per evaluation, use -1 to print all"
    ))

    tensorboard_logger_history_size: int = field(default=5, metadata=dc_meta(
        help="Number of instances to store for outputing into tensorboard. Default (last n=5)"
    ))

    # Which export formats are written
    export_frozen: bool = field(default=False, metadata=dc_meta(
        help="Export the frozen graph alongside the saved model"
    ))
    export_serve: bool = field(default=True, metadata=dc_meta(
        help="Export the serving model (saved model format)"
    ))

    # Params of the scenario components; default_factory gives every instance its own params object
    # (a shared default instance would be mutated across scenarios)
    model_params: ModelBaseParams = field(default_factory=ModelBaseParams)
    data_params: DataBaseParams = field(default_factory=DataBaseParams)
    evaluator_params: EvaluatorParams = field(default_factory=EvaluatorParams)

    # The fields below use a trailing underscore and carry no dc_meta.
    # NOTE(review): presumably they are internal settings that are not exposed as command line
    # arguments — confirm against tfaip.util.argumentparser.

    # Additional export params: flags and file/directory names used when exporting
    # (presumably relative to the export directory — confirm at the export call site)
    export_net_config_: bool = True
    net_config_filename_: str = 'net_config.json'
    frozen_dir_: str = 'frozen'
    frozen_filename_: str = 'frozen_model.pb'
    default_serve_dir_: str = 'serve'
    additional_serve_dir_: str = 'additional'
    trainer_params_filename_: str = 'trainer_params.json'
    scenario_params_filename_: str = 'scenario_params.json'

    # Scenario location/identification; None by default (presumably filled in at runtime — confirm)
    scenario_base_path_: Optional[str] = None
    scenario_module_: Optional[str] = None
    id_: Optional[str] = None

    # Provenance: the commit hash is resolved by get_commit_hash each time an instance is created,
    # the version is the tfaip __version__ captured when this module is imported
    tfaip_commit_hash_: str = field(default_factory=get_commit_hash)
    tfaip_version_: str = __version__


@dataclass_json
@dataclass
class NetConfigNodeSpec:
    """
    Specification of a single input or output node of an exported network.

    Used as the value type of NetConfigParamsBase.in_nodes / out_nodes. Decorated with
    @dataclass_json like the other params classes, so that it can be (de)serialized on its own
    (to_dict / from_dict / to_json / from_json). The decorator also lets dataclasses_json build a
    proper nested schema for it in NetConfigParamsBase.schema(); undecorated nested dataclasses
    are only mapped to a generic field, and a warning is emitted.
    """
    # NOTE(review): dimensions are stored as strings, presumably so that unknown dimensions can be
    # represented — confirm against the code that fills this spec
    shape: List[str]
    dtype: str
    # Presumably the node names in the frozen graph / serving model; serialized under the
    # camelCase keys "nodeFrozen" / "nodeServe"
    node_frozen: str = field(metadata=config(letter_case=LetterCase.CAMEL))
    node_serve: str = field(metadata=config(letter_case=LetterCase.CAMEL))


@dataclass_json
@dataclass
class NetConfigParamsBase:
    """
    Base description of an exported network, (de)serialized as JSON with camelCase keys.

    Holds the model identifiers (id_model / id_frozen / id_serve), the specifications of the
    input and output nodes, and version information.
    NOTE(review): presumably this is the content written to ScenarioBaseParams.net_config_filename_
    ('net_config.json') — confirm at the export call site.
    """
    # Every field overrides its JSON key to camelCase via its metadata (e.g. id_model <-> "idModel")
    id_model: str = field(metadata=config(letter_case=LetterCase.CAMEL))
    id_frozen: str = field(metadata=config(letter_case=LetterCase.CAMEL))
    id_serve: str = field(metadata=config(letter_case=LetterCase.CAMEL))
    # Mapping from a node key to its NetConfigNodeSpec
    in_nodes: Dict[str, NetConfigNodeSpec] = field(metadata=config(letter_case=LetterCase.CAMEL))
    out_nodes: Dict[str, NetConfigNodeSpec] = field(metadata=config(letter_case=LetterCase.CAMEL))

    # tf_version has no default, so it must stay before the defaulted fields below
    # (dataclasses reject non-default fields that follow default ones)
    tf_version: str = field(metadata=config(letter_case=LetterCase.CAMEL))
    # Provenance: commit hash resolved by get_commit_hash per instance; version captured at import time
    tfaip_commit_hash: str = field(default_factory=get_commit_hash, metadata=config(letter_case=LetterCase.CAMEL))
    tfaip_version: str = field(default=__version__, metadata=config(letter_case=LetterCase.CAMEL))
