#
#  Copyright (c) 2022 IBM Corp.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import os
import random
import numpy as np

from label_sleuth.models.core.models_background_jobs_manager import ModelsBackgroundJobsManager
from label_sleuth.models.core.model_api import ModelAPI, ModelStatus
from label_sleuth.models.core.prediction import Prediction


class RandomModel(ModelAPI):
    """
    Mock classification model that does not train, and returns random classification predictions.

    Every call to ``_train`` assigns the new model the next integer seed (0, 1, 2, ...). Inference for a
    model draws its scores from a ``random.Random`` seeded with that model's seed. Repeated inference on the
    same model with the same number of items therefore returns the same predictions.
    """
    def __init__(self, output_dir, models_background_jobs_manager: ModelsBackgroundJobsManager):
        """
        :param output_dir: base directory; the subdirectory ``<output_dir>/random`` is created (if missing)
        and reported as this model type's models directory
        :param models_background_jobs_manager: passed through to the ModelAPI base class
        """
        super().__init__(models_background_jobs_manager)
        self.model_dir = os.path.join(output_dir, "random")
        os.makedirs(self.model_dir, exist_ok=True)

        # model_id -> seed used to generate that model's random predictions
        self.model_id_to_random_seed = {}
        # last seed handed out; starts at -1 so that the first trained model receives seed 0
        self.random_seed = -1

    def _train(self, model_id, train_data, model_params):
        """
        "Train" a model by assigning it the next unused random seed.

        :param model_id: id under which the seed is recorded
        :param train_data: ignored
        :param model_params: ignored
        """
        self.random_seed += 1
        self.model_id_to_random_seed[model_id] = self.random_seed

    def _infer(self, model_id, items_to_infer):
        """
        Return one random prediction per item.

        :param model_id: id of a model previously passed to ``_train`` (and not deleted since)
        :param items_to_infer: sized collection of items; only its length is used
        :return: list of Prediction, one per item. Each score is drawn uniformly from [0.0, 1.0), and
        each label is True iff its score > 0.5
        :raises KeyError: if model_id has no recorded seed
        """
        rand = random.Random(self.model_id_to_random_seed[model_id])
        # Plain Python floats/bools (rather than numpy scalars) so predictions compare, type-check
        # and serialize like ordinary values (e.g. numpy.bool_ is not JSON-serializable).
        scores = [rand.random() for _ in range(len(items_to_infer))]
        return [Prediction(label=score > 0.5, score=score) for score in scores]

    def get_model_status(self, model_id):
        """
        :return: ModelStatus.READY if model_id has been trained and not deleted, otherwise ModelStatus.ERROR
        """
        if model_id in self.model_id_to_random_seed:
            return ModelStatus.READY
        return ModelStatus.ERROR

    def get_models_dir(self):
        """
        :return: the ``<output_dir>/random`` directory created in ``__init__``
        """
        return self.model_dir

    def delete_model(self, model_id):
        """
        Forget the seed recorded for model_id; a no-op if the model is unknown. Nothing is written to disk
        by this class, so there are no model files to remove.
        """
        self.model_id_to_random_seed.pop(model_id, None)
