#Tabular Regression Prediction Runtime Code

import boto3
import pandas as pd
import os
import numpy as np
import onnxruntime as rt
import json

def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the ONNX model from S3 and open an inference session on it.

    Args:
        runtimemodel_s3_filename: File name of the model object stored under
            the model's key prefix in the bucket.

    Returns:
        The onnxruntime ``InferenceSession`` created from the downloaded
        model bytes.
    """
    s3 = boto3.resource('s3')
    # Build the key from the argument so that a non-default file name is
    # honored (same key layout as get_runtimedata).
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    model_bytes = obj.get()['Body'].read()
    return rt.InferenceSession(model_bytes)

def _get_pyspark_modules():
    """Map capitalized callables of the pyspark.ml modules to their module path.

    Every attribute of the listed pyspark modules that is callable and whose
    name starts with an ASCII uppercase letter is recorded.

    Returns:
        dict mapping the attribute name to the dotted path of the module that
        exposes it (for example ``'pyspark.ml.feature'``). When several
        modules expose the same name, the module listed last wins.
    """
    import importlib
    import re

    pyspark_modules = ['ml', 'ml.feature', 'ml.classification', 'ml.clustering', 'ml.regression']
    starts_uppercase = re.compile('^[A-Z]')

    models_modules_dict = {}

    for suffix in pyspark_modules:
        module_path = 'pyspark.' + suffix
        # Resolve the module once through the import system instead of
        # evaluating a dotted-name string for every attribute.
        module = importlib.import_module(module_path)

        for name in dir(module):
            if callable(getattr(module, name)) and starts_uppercase.match(name):
                models_modules_dict[name] = module_path

    return models_modules_dict

def pyspark_model_from_string(model_type):
    """Return the pyspark object whose name is ``model_type``.

    Args:
        model_type: Name of a capitalized callable found by
            ``_get_pyspark_modules`` (for example a pyspark.ml class name).

    Returns:
        The attribute called ``model_type`` of the module that exposes it.

    Raises:
        KeyError: If ``model_type`` is not among the discovered names.
    """
    from importlib import import_module

    owning_module = import_module(_get_pyspark_modules()[model_type])
    return getattr(owning_module, model_type)

def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor bundle from S3 and return its ``preprocessor``.

    The bundle is unpacked into a per-model folder below the system temp
    directory. Every ``.pkl`` file found there is unpickled into a module
    global named after the file. Every ``.zip`` file found there is treated
    as a saved pyspark object: its file name is read as
    ``<Type>_...__<objectname>.zip``, the archive is extracted next to it,
    loaded through the pyspark class ``<Type>`` and stored as the module
    global ``<objectname>``. Finally ``preprocessor.py`` from the bundle is
    executed in the module namespace; it is expected to define
    ``preprocessor``.

    Args:
        preprocessor_s3_filename: File name of the bundle stored under the
            model's key prefix in the bucket.

    Returns:
        The ``preprocessor`` object defined by executing ``preprocessor.py``.
    """
    import os
    import pickle
    import tempfile
    from io import BytesIO
    from zipfile import ZipFile

    # Work in a per-model subfolder: the shared temp directory may contain
    # other zip files that would otherwise be picked up by the scan below.
    temp_dir = os.path.join(tempfile.gettempdir(), "$unique_model_id")
    os.makedirs(temp_dir, exist_ok=True)

    s3 = boto3.resource("s3")
    # Build the key from the argument so a non-default file name is honored.
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id" + "/" + preprocessor_s3_filename)
    with ZipFile(BytesIO(zip_obj.get()["Body"].read())) as bundle:
        # Extract all the contents of the bundle into the temp folder
        bundle.extractall(temp_dir)

    # Collect the pickled objects and the nested (pyspark) archives that the
    # bundle shipped.
    pickle_file_list = []
    zip_file_list = []
    for entry in os.listdir(temp_dir):
        if entry.endswith(".pkl"):
            pickle_file_list.append(os.path.join(temp_dir, entry))
        if entry.endswith(".zip"):
            zip_file_list.append(os.path.join(temp_dir, entry))

    # NOTE(security): pickle.load and the exec call below run arbitrary code
    # taken from the bundle, so the bucket contents must be trusted.
    for pickle_path in pickle_file_list:
        objectname = os.path.splitext(os.path.basename(pickle_path))[0]
        with open(pickle_path, "rb") as pickle_file:
            globals()[objectname] = pickle.load(pickle_file)

    # Need spark session and context to instantiate model object
    # zip_file_list is only used by pyspark
    if zip_file_list:
        from pyspark.sql import SparkSession

        spark = SparkSession \
            .builder \
            .appName('Pyspark Model') \
            .getOrCreate()

        for zip_path in zip_file_list:
            # Strip only the trailing extension; a plain str.replace on the
            # full path would also rewrite ".zip" inside directory names.
            dir_path = os.path.splitext(zip_path)[0]
            objectnames = os.path.basename(dir_path).split("__")
            os.makedirs(dir_path, exist_ok=True)

            # Extract the saved object into its own folder
            with ZipFile(zip_path, 'r') as zipObj:
                zipObj.extractall(dir_path)

            preprocessor_type = objectnames[0].split("_")[0]
            objectname = objectnames[1]
            preprocessor_class = pyspark_model_from_string(preprocessor_type)
            if preprocessor_type == "PipelineModel":
                print(preprocessor_class)
                preprocessor_model = preprocessor_class(stages=None)
            else:
                preprocessor_model = preprocessor_class()

            preprocessor_model = preprocessor_model.load(dir_path)
            globals()[objectname] = preprocessor_model

    # Import the preprocessor function into the module namespace from
    # preprocessor.py; the objects loaded above are visible to it as globals.
    with open(os.path.join(temp_dir, 'preprocessor.py')) as source_file:
        source = source_file.read()
    exec(source, globals())
    return preprocessor

def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch the runtime metadata JSON document from S3 and parse it.

    Args:
        runtimedata_s3_filename: File name of the JSON object stored under
            the model's key prefix in the bucket.

    Returns:
        The parsed JSON content of the object.
    """
    object_key = "$unique_model_id" + "/" + runtimedata_s3_filename
    response = boto3.resource('s3').Object("$bucket_name", object_key).get()
    return json.load(response['Body'])


# Runtime metadata stored next to the model; read once at import time
runtime_data = get_runtimedata("runtime_data.json")
preprocessor_type = runtime_data["runtime_preprocessor"]
runtime_model = runtime_data["runtime_model"]["name"]

# Open the ONNX inference session once so every request reuses it
model = get_model_onnx("runtime_model.onnx")

# Build the preprocessing callable once so every request reuses it
preprocessor = get_preprocessor("runtime_preprocessor.zip")


def predict(event, model, preprocessor):
    """Run the model on the data carried in ``event["body"]``.

    Args:
        event: Mapping with a "body" entry. The body is either a JSON
            document (str, bytes or bytearray) or an already-parsed mapping;
            in both cases it must contain a "data" entry that
            ``pandas.DataFrame.from_dict`` accepts.
        model: Object exposing ``get_inputs()`` and ``run()`` in the style of
            an onnxruntime inference session.
        preprocessor: Callable that receives the DataFrame and returns an
            object supporting ``astype('float32')``.

    Returns:
        The first entry of the first model output after conversion to Python
        lists, i.e. ``res[0].tolist()[0]``.
    """
    body = event["body"]
    # A serialized body is decoded first; json.loads handles text and bytes
    # alike, so no compatibility shim is needed for the string check.
    if isinstance(body, (str, bytes, bytearray)):
        body = json.loads(body)

    print(body["data"])
    bodydata = pd.DataFrame.from_dict(body["data"])
    print(bodydata)

    input_name = model.get_inputs()[0].name
    print(input_name)

    # Preprocess data; the model input needs to be float32
    input_data = preprocessor(bodydata).astype('float32')
    print(input_data)

    # Generate prediction using preprocessed input data
    res = model.run(None, {input_name: input_data})

    return res[0].tolist()[0]

def handler(event, context):
    """Serve one prediction request (presumably the AWS Lambda entry point).

    Args:
        event: Request mapping forwarded unchanged to ``predict``.
        context: Unused.

    Returns:
        dict with status code 200, permissive CORS headers and the
        JSON-encoded prediction as the body.
    """
    response_headers = {
        "Access-Control-Allow-Origin" : "*",
        "Access-Control-Allow-Credentials": True,
        "Allow" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods" : "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers" : "*"
    }
    prediction = predict(event, model, preprocessor)
    return {
        "statusCode": 200,
        "headers": response_headers,
        "body": json.dumps(prediction),
    }
