"""
Helpers for loading a saved Stable-Baselines3 model, or training and saving a new one.
"""
import os

import gym
from stable_baselines3 import PPO


def load_model(
    env: gym.Env,
    filepath: str,
    model_type: str,
    policy: str,
    total_timesteps: int = 10_000,
) -> PPO:
    """
    Load a saved model from ``filepath``, or train and save a new one.

    If a saved model file is found it is loaded and attached to ``env``.
    Otherwise a new model is built from ``policy`` and ``env``, trained for
    ``total_timesteps`` steps, saved to ``filepath`` and returned.

    Args:
        env: Environment attached to the model, both when loading and training.
        filepath: Path of the saved model. When it has no extension, the
            ``.zip`` variant of the path is also checked.
        model_type: Name of the algorithm; only ``"PPO"`` is supported.
        policy: Policy identifier forwarded to the model constructor.
        total_timesteps: Training budget used when no saved model exists.

    Returns:
        The loaded or newly trained model.

    Raises:
        ValueError: If ``model_type`` is not a supported algorithm name.
    """
    if model_type != "PPO":
        raise ValueError(f"'{model_type}' is not supported yet")
    model_class = PPO

    # Stable-Baselines3 appends ".zip" on save when the path has no extension,
    # so a model saved as "model" lives at "model.zip". Check both spellings,
    # otherwise the saved model is never found and training reruns every call.
    candidates = [filepath]
    if not os.path.splitext(filepath)[1]:
        candidates.append(f"{filepath}.zip")
    saved_path = next((path for path in candidates if os.path.isfile(path)), None)

    if saved_path is not None:
        # Pass env so the loaded model is usable for further training, just
        # like a freshly trained one.
        return model_class.load(saved_path, env=env)

    model = model_class(policy, env)
    # The step budget is a count, so make sure it is an integer.
    model.learn(total_timesteps=int(total_timesteps))
    model.save(filepath)
    return model
