#Text Regression Prediction Runtime Code

import boto3
import pandas as pd
import os
from io import BytesIO
import pickle
import numpy as np
import json
import onnxruntime as rt
import warnings
import six

def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the serialized ONNX model from S3 and open a session on it.

    Args:
        runtimemodel_s3_filename: File name of the ONNX model, looked up
            under the model-id prefix of the configured bucket.

    Returns:
        The ``onnxruntime.InferenceSession`` built from the downloaded bytes.
    """
    s3 = boto3.resource('s3')
    # Build the key from the argument so a non-default file name is honored
    # (the key used to be hard-coded and the parameter silently ignored).
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    model_bytes = obj.get()['Body'].read()
    return rt.InferenceSession(model_bytes)

def _get_pyspark_modules():
    """Map capitalized callable names in the pyspark.ml modules to their module.

    The modules inspected are ``pyspark.ml`` and its ``feature``,
    ``classification``, ``clustering`` and ``regression`` submodules.

    Returns:
        A dict mapping each attribute name that starts with an ASCII
        uppercase letter and is callable to the dotted path of the module
        it was found in (for example ``'pyspark.ml.feature'``). When several
        modules expose the same name, the module listed last wins.
    """
    import importlib
    import re

    pyspark_modules = ['ml', 'ml.feature', 'ml.classification', 'ml.clustering', 'ml.regression']
    starts_uppercase = re.compile('^[A-Z]')

    models_modules_dict = {}

    for suffix in pyspark_modules:
        module_path = 'pyspark.' + suffix
        # Resolve each module once through importlib instead of calling
        # eval() on a constructed string for every attribute lookup.
        module = importlib.import_module(module_path)

        for attr_name in dir(module):
            # Cheap name filter first; getattr only runs on the candidates.
            if starts_uppercase.match(attr_name) and callable(getattr(module, attr_name)):
                models_modules_dict[attr_name] = module_path

    return models_modules_dict

def pyspark_model_from_string(model_type):
    """Return the pyspark class whose name is ``model_type``.

    Args:
        model_type: Name of a class as listed by ``_get_pyspark_modules``.

    Returns:
        The attribute called ``model_type`` on the module recorded for it.

    Raises:
        KeyError: If ``model_type`` is not a known name.
    """
    from importlib import import_module

    # Look up which module exposes the requested name, then fetch it from there.
    owning_module_path = _get_pyspark_modules()[model_type]
    owning_module = import_module(owning_module_path)
    return getattr(owning_module, model_type)

def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor bundle from S3 and return the preprocessor.

    The bundle is unpacked into a model-specific folder under the system
    temp directory. Every ``.pkl`` file found there is unpickled and every
    nested ``.zip`` file (a saved pyspark model) is loaded; each resulting
    object is stored as a module global so that ``preprocessor.py`` can
    refer to it by name. Finally ``preprocessor.py`` from the bundle is
    executed in the module namespace.

    Args:
        preprocessor_s3_filename: File name of the zip bundle, looked up
            under the model-id prefix of the configured bucket.

    Returns:
        The module-level ``preprocessor`` object, which ``preprocessor.py``
        is expected to define.
    """
    import os
    import pickle
    import tempfile
    from io import BytesIO
    from pathlib import Path
    from zipfile import ZipFile

    # Work in a model-specific sub-folder: other zip files lying around in
    # the shared temp directory would otherwise be picked up by the scan below.
    temp_dir = os.path.join(tempfile.gettempdir(), "$unique_model_id")
    os.makedirs(temp_dir, exist_ok=True)

    s3 = boto3.resource("s3")
    # Build the key from the argument so a non-default file name is honored.
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id" + "/" + preprocessor_s3_filename)
    with ZipFile(BytesIO(zip_obj.get()["Body"].read())) as bundle:
        bundle.extractall(temp_dir)

    # Collect the artifacts the preprocessor code expects to find as globals.
    pickle_file_list = []
    zip_file_list = []
    for file_name in os.listdir(temp_dir):
        if file_name.endswith(".pkl"):
            pickle_file_list.append(os.path.join(temp_dir, file_name))
        elif file_name.endswith(".zip"):
            zip_file_list.append(os.path.join(temp_dir, file_name))

    for pickle_path in pickle_file_list:
        # Strip only the trailing extension to get the global's name.
        objectname = os.path.basename(pickle_path)[:-len(".pkl")]
        # NOTE(security): unpickling runs arbitrary code; the bucket content
        # must be trusted.
        with open(pickle_path, "rb") as pickle_file:
            globals()[objectname] = pickle.load(pickle_file)

    # zip_file_list is only used by pyspark
    if zip_file_list:
        from pyspark.sql import SparkSession

        # Need spark session and context to instantiate model object; the
        # session handle itself is not used afterwards.
        SparkSession \
            .builder \
            .appName('Pyspark Model') \
            .getOrCreate()

        for zip_path in zip_file_list:
            # Strip only the trailing extension (not every ".zip" in the path).
            dir_path = zip_path[:-len(".zip")]
            # Archive names have the form "<Type>[_...]__<objectname>.zip".
            objectnames = os.path.basename(dir_path).split("__")
            Path(dir_path).mkdir(parents=True, exist_ok=True)

            with ZipFile(zip_path, 'r') as model_zip:
                model_zip.extractall(dir_path)

            preprocessor_type = objectnames[0].split("_")[0]
            objectname = objectnames[1]
            preprocessor_class = pyspark_model_from_string(preprocessor_type)
            if preprocessor_type == "PipelineModel":
                # PipelineModel is constructed with an explicit stages argument.
                preprocessor_model = preprocessor_class(stages=None)
            else:
                preprocessor_model = preprocessor_class()

            preprocessor_model = preprocessor_model.load(dir_path)
            globals()[objectname] = preprocessor_model

    # Import the preprocessor function into the session from preprocessor.py.
    with open(os.path.join(temp_dir, 'preprocessor.py')) as source_file:
        preprocessor_source = source_file.read()
    # NOTE(security): exec runs the downloaded file with full privileges; the
    # bucket content must be trusted.
    exec(preprocessor_source, globals())
    return preprocessor

def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch the runtime metadata JSON from S3 and return it parsed.

    Args:
        runtimedata_s3_filename: File name of the JSON document, looked up
            under the model-id prefix of the configured bucket.

    Returns:
        The decoded JSON content of the object.
    """
    s3_resource = boto3.resource('s3')
    object_key = "$unique_model_id" + "/" + runtimedata_s3_filename
    s3_object = s3_resource.Object("$bucket_name", object_key)

    # json.load reads directly from the response body stream.
    response_body = s3_object.get()['Body']
    return json.load(response_body)


# Everything below runs once at import time, so the artifacts are loaded a
# single time and then reused by every call to handler().

# Runtime metadata describing the deployed model and its preprocessor.
runtime_data = get_runtimedata("runtime_data.json")
preprocessor_type = runtime_data["runtime_preprocessor"]
runtime_model = runtime_data["runtime_model"]["name"]

# ONNX inference session.
model = get_model_onnx("runtime_model.onnx")

# Preprocessor callable applied to incoming data before inference.
preprocessor = get_preprocessor("runtime_preprocessor.zip")


def predict(event, model, preprocessor):
    """Run the model on the data carried by an API event.

    Args:
        event: Mapping with a "body" entry that is either a JSON string or
            an already-decoded mapping; either way it must provide a "data"
            entry holding the input value(s).
        model: Session object exposing ``get_inputs()`` and ``run()``, such
            as the one returned by ``get_model_onnx``.
        preprocessor: Callable that receives the data as a pandas Series and
            returns an array supporting ``astype``.

    Returns:
        The first element of the model's first output, converted to plain
        Python values with ``tolist()``.
    """
    body = event["body"]
    # The body arrives either as a JSON string or as an already-parsed mapping.
    if isinstance(body, str):
        body = json.loads(body)
    bodynew = pd.Series(body["data"])

    input_name = model.get_inputs()[0].name
    input_data = preprocessor(bodynew).astype(np.float32)  # needs to be float32

    res = model.run(None, {input_name: input_data})
    return res[0].tolist()[0]

def handler(event, context):
    """Entry point: run a prediction for ``event`` and wrap it in a response.

    Args:
        event: Request event, passed through to ``predict``.
        context: Invocation context; not used here.

    Returns:
        A dict with "statusCode" 200, the CORS-related "headers", and a
        "body" holding the JSON-encoded prediction.
    """
    prediction = predict(event, model, preprocessor)

    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Credentials": True,
        "Allow": "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods": "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers": "*",
    }

    response = {
        "statusCode": 200,
        "headers": cors_headers,
        "body": json.dumps(prediction),
    }
    return response
