# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model specification."""

import functools
import inspect

from tensorflow_examples.lite.model_maker.core.api import mm_export
from tensorflow_examples.lite.model_maker.core.task.model_spec import audio_spec
from tensorflow_examples.lite.model_maker.core.task.model_spec import image_spec
from tensorflow_examples.lite.model_maker.core.task.model_spec import object_detector_spec
from tensorflow_examples.lite.model_maker.core.task.model_spec import recommendation_spec
from tensorflow_examples.lite.model_maker.core.task.model_spec import text_spec


# Registry of model specs, keyed by the string name accepted by `get()`.
# Values are spec classes, functions, `functools.partial` objects or other
# objects; `get()` calls classes, functions and partials with no arguments to
# build an instance and returns any other value unchanged.
MODEL_SPECS = {
    # Image classification
    'efficientnet_lite0': image_spec.efficientnet_lite0_spec,
    'efficientnet_lite1': image_spec.efficientnet_lite1_spec,
    'efficientnet_lite2': image_spec.efficientnet_lite2_spec,
    'efficientnet_lite3': image_spec.efficientnet_lite3_spec,
    'efficientnet_lite4': image_spec.efficientnet_lite4_spec,
    'mobilenet_v2': image_spec.mobilenet_v2_spec,
    'resnet_50': image_spec.resnet_50_spec,

    # Text classification
    'average_word_vec': text_spec.AverageWordVecModelSpec,
    # NOTE(review): 'bert' is not listed in TEXT_CLASSIFICATION_MODELS or in
    # any other *_MODELS constant of this module; confirm that is intentional.
    'bert': text_spec.BertModelSpec,
    'bert_classifier': text_spec.BertClassifierModelSpec,
    'mobilebert_classifier': text_spec.mobilebert_classifier_spec,

    # Question answering
    'bert_qa': text_spec.BertQAModelSpec,
    'mobilebert_qa': text_spec.mobilebert_qa_spec,
    'mobilebert_qa_squad': text_spec.mobilebert_qa_squad_spec,

    # Audio classification
    'audio_browser_fft': audio_spec.BrowserFFTSpec,
    # Alias: resolves to the same spec class as 'audio_browser_fft'.
    'audio_teachable_machine': audio_spec.BrowserFFTSpec,
    'audio_yamnet': audio_spec.YAMNetSpec,

    # Recommendation
    'recommendation_bow': recommendation_spec.recommendation_bow_spec,
    'recommendation_cnn': recommendation_spec.recommendation_cnn_spec,
    'recommendation_rnn': recommendation_spec.recommendation_rnn_spec,

    # Object detection
    'efficientdet_lite0': object_detector_spec.efficientdet_lite0_spec,
    'efficientdet_lite1': object_detector_spec.efficientdet_lite1_spec,
    'efficientdet_lite2': object_detector_spec.efficientdet_lite2_spec,
    'efficientdet_lite3': object_detector_spec.efficientdet_lite3_spec,
    'efficientdet_lite4': object_detector_spec.efficientdet_lite4_spec,
}

# List constants for supported models, grouped by task. Every entry is a key
# of MODEL_SPECS, so these lists must be kept in sync with MODEL_SPECS when a
# model is added or removed. Each list is exported through `mm_export` below.
IMAGE_CLASSIFICATION_MODELS = [
    'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2',
    'efficientnet_lite3', 'efficientnet_lite4', 'mobilenet_v2', 'resnet_50'
]
TEXT_CLASSIFICATION_MODELS = [
    'bert_classifier', 'average_word_vec', 'mobilebert_classifier'
]
QUESTION_ANSWER_MODELS = ['bert_qa', 'mobilebert_qa', 'mobilebert_qa_squad']
# 'audio_teachable_machine' maps to the same spec as 'audio_browser_fft' in
# MODEL_SPECS.
AUDIO_CLASSIFICATION_MODELS = [
    'audio_browser_fft', 'audio_teachable_machine', 'audio_yamnet'
]
RECOMMENDATION_MODELS = [
    'recommendation_bow',
    'recommendation_rnn',
    'recommendation_cnn',
]
OBJECT_DETECTION_MODELS = [
    'efficientdet_lite0',
    'efficientdet_lite1',
    'efficientdet_lite2',
    'efficientdet_lite3',
    'efficientdet_lite4',
]
# Register each supported-model list with the API exporter. Each pair is
# (exported API name, name of the module-level constant in this module).
for _api_name, _constant_name in (
    ('model_spec.IMAGE_CLASSIFICATION_MODELS', 'IMAGE_CLASSIFICATION_MODELS'),
    ('model_spec.TEXT_CLASSIFICATION_MODELS', 'TEXT_CLASSIFICATION_MODELS'),
    ('model_spec.QUESTION_ANSWER_MODELS', 'QUESTION_ANSWER_MODELS'),
    ('model_spec.AUDIO_CLASSIFICATION_MODELS', 'AUDIO_CLASSIFICATION_MODELS'),
    ('model_spec.RECOMMENDATION_MODELS', 'RECOMMENDATION_MODELS'),
    ('model_spec.OBJECT_DETECTION_MODELS', 'OBJECT_DETECTION_MODELS'),
):
  mm_export(_api_name).export_constant(__name__, _constant_name)
# Drop the loop variables so they do not linger in the module namespace.
del _api_name, _constant_name


@mm_export('model_spec.get')
def get(spec_or_str):
  """Gets model spec by name or instance, and initializes by default.

  Args:
    spec_or_str: Either a string key of `MODEL_SPECS`, or a model spec object.
      A spec given as a class, a function or a `functools.partial` is called
      with no arguments to build an instance; any other object is returned
      unchanged.

  Returns:
    The model spec instance.

  Raises:
    KeyError: If `spec_or_str` is a string that is not a key of `MODEL_SPECS`.
      The message lists the supported names.
  """
  if isinstance(spec_or_str, str):
    try:
      model_spec = MODEL_SPECS[spec_or_str]
    except KeyError:
      # Keep raising KeyError for callers that catch it, but name the valid
      # choices instead of echoing only the bad key.
      raise KeyError('Unknown model spec: {!r}. Supported model specs: {}'
                     .format(spec_or_str, sorted(MODEL_SPECS))) from None
  else:
    model_spec = spec_or_str

  # Specs may be registered as factories (classes, functions or partials); call
  # them with no arguments so callers always get an instance back.
  if inspect.isclass(model_spec) or inspect.isfunction(
      model_spec) or isinstance(model_spec, functools.partial):
    return model_spec()
  return model_spec
