# Tabular time series regression

import boto3
import pandas as pd
import os
import numpy as np
import onnxruntime as rt
import json

def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the ONNX model from S3 and open an inference session on it.

    Args:
        runtimemodel_s3_filename: object name under the model's S3 prefix.
            Previously this argument was ignored and the key was hard-coded;
            the default reproduces that old key, so existing callers are
            unaffected.

    Returns:
        An ``onnxruntime.InferenceSession`` built from the downloaded bytes.
    """
    s3 = boto3.resource('s3')
    key = "$unique_model_id" + "/" + runtimemodel_s3_filename
    obj = s3.Object("$bucket_name", key)
    # InferenceSession accepts the serialized model bytes directly, so the
    # model never has to be written to local disk.
    model_bytes = obj.get()['Body'].read()
    return rt.InferenceSession(model_bytes)

def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor bundle from S3 and load it into this module.

    The zip archive is extracted into /tmp. Every top-level ``*.pkl`` file
    found there is unpickled and bound as a module global named after the
    file (``scaler.pkl`` -> ``scaler``). ``preprocessor.py`` is then executed
    in this module's globals, presumably so its ``preprocessor`` function can
    refer to those unpickled objects by name.

    Args:
        preprocessor_s3_filename: archive name under the model's S3 prefix.
            Previously this argument was ignored; the default reproduces the
            old hard-coded key.

    Returns:
        The ``preprocessor`` callable defined by ``preprocessor.py``.

    Raises:
        NameError: if ``preprocessor.py`` does not define ``preprocessor``.
    """
    import pickle
    from zipfile import ZipFile
    from io import BytesIO
    import os

    # NOTE(review): /tmp is presumably used because it is the writable
    # location in AWS Lambda.
    extract_dir = "/tmp"

    s3 = boto3.resource("s3")
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id" + "/" + preprocessor_s3_filename)
    buffer = BytesIO(zip_obj.get()["Body"].read())
    with ZipFile(buffer) as archive:
        archive.extractall(extract_dir)

    # SECURITY: pickle.load and exec below run arbitrary code taken from the
    # S3 bucket; this is only safe while the bucket contents are trusted.
    # NOTE(review): this scans all of /tmp, not only files from this archive.
    for entry in os.listdir(extract_dir):
        if not entry.endswith(".pkl"):
            continue
        object_name = entry.replace(".pkl", "")
        with open(os.path.join(extract_dir, entry), "rb") as pkl_file:
            globals()[object_name] = pickle.load(pkl_file)

    # Execute preprocessor.py in module globals so it defines `preprocessor`.
    with open(os.path.join(extract_dir, "preprocessor.py")) as source_file:
        source = source_file.read()
    exec(source, globals())

    try:
        return globals()["preprocessor"]
    except KeyError:
        raise NameError(
            "preprocessor.py did not define a 'preprocessor' function"
        ) from None

def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch and decode the model's runtime metadata JSON from S3.

    Args:
        runtimedata_s3_filename: object name under the model's S3 prefix.

    Returns:
        The parsed JSON document (whatever ``json.load`` produces).
    """
    object_key = "/".join(["$unique_model_id", runtimedata_s3_filename])
    response = boto3.resource('s3').Object("$bucket_name", object_key).get()
    return json.load(response['Body'])


# Module-level initialisation. These statements run once, when the module is
# imported (presumably once per Lambda cold start), so warm invocations of
# `handler` reuse the objects below instead of re-downloading them from S3.

# Runtime metadata (JSON) stored alongside the model artifacts.
runtime_data = get_runtimedata(runtimedata_s3_filename="runtime_data.json")

# NOTE(review): not referenced elsewhere in this file as shown; kept as
# module-level metadata.
preprocessor_type = runtime_data["runtime_preprocessor"]

# The model's name string from the metadata (not the model object itself).
runtime_model = runtime_data["runtime_model"]["name"]

# Load model
model = get_model_onnx(runtimemodel_s3_filename='runtime_model.onnx')

# Load preprocessor
# get_preprocessor exec()s preprocessor.py into this module's globals, which
# already binds `preprocessor`; this assignment rebinds the same name to the
# returned callable. Keep this after the other loads: it mutates globals.
preprocessor = get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip")

def predict(event, model, preprocessor):
    """Run the ONNX model on the tabular rows carried in a Lambda event.

    Args:
        event: Lambda event dict. ``event["body"]`` is either a JSON string
            or an already-decoded dict; either way it must contain a
            ``"data"`` entry accepted by ``pandas.DataFrame.from_dict``.
        model: an onnxruntime-style session exposing ``get_inputs()`` and
            ``run(output_names, input_feed)``.
        preprocessor: callable mapping the request DataFrame to model input.

    Returns:
        ``res[0][0].tolist()``: the first entry of the model's first output,
        converted to plain Python values.
    """
    body = event["body"]
    # Presumably API Gateway proxy integrations deliver the body as a JSON
    # string, while direct invocations may pass an already-parsed dict.
    if isinstance(body, str):
        body = json.loads(body)
    print(body["data"])
    bodydata = pd.DataFrame.from_dict(body["data"])

    # Call the preprocessor exactly once. The old bare `except:` re-ran it
    # after any failure, which only hid the original error.
    input_data = preprocessor(bodydata)
    # Sparse matrices (e.g. scipy.sparse output) cannot be fed to the model
    # directly; densify them as float32, as the old fallback path intended.
    if hasattr(input_data, "toarray"):
        input_data = input_data.astype(np.float32).toarray()

    # generate prediction using preprocessed input data
    input_name = model.get_inputs()[0].name
    res = model.run(None, {input_name: input_data})

    return res[0][0].tolist()


def handler(event, context):
    """AWS Lambda entry point: predict on the event and wrap it for HTTP.

    Uses the module-level ``model`` and ``preprocessor`` loaded at import.
    The ``context`` argument is accepted for the Lambda signature but unused.

    Returns:
        A proxy-integration response dict with status 200, CORS headers and
        the JSON-encoded prediction as the body.
    """
    prediction = predict(event, model, preprocessor)
    # NOTE(review): browsers reject credentialed requests when
    # Allow-Origin is "*"; confirm whether Allow-Credentials is needed.
    cors_headers = dict([
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Credentials", True),
        ("Allow", "GET, OPTIONS, POST"),
        ("Access-Control-Allow-Methods", "GET, OPTIONS, POST"),
        ("Access-Control-Allow-Headers", "*"),
    ])
    response = {}
    response["statusCode"] = 200
    response["headers"] = cors_headers
    response["body"] = json.dumps(prediction)
    return response
