# video classification
import boto3
import os
import numpy as np
import json
import onnxruntime as rt
import base64
import six
from os import path

import json
import warnings
import numpy as np

def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the ONNX model from S3 and open an inference session on it.

    Args:
        runtimemodel_s3_filename: Name of the model object. It is appended
            to the model-id key prefix hard-coded below, inside the bucket
            hard-coded below.

    Returns:
        The ``onnxruntime.InferenceSession`` built from the downloaded bytes.
    """
    s3 = boto3.resource('s3')
    # Build the key from the argument; the default keeps the key that was
    # previously hard-coded, so existing callers see no difference.
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    # InferenceSession accepts the serialized model as raw bytes, so the
    # model never has to be written to disk.
    model_bytes = obj.get()['Body'].read()
    return rt.InferenceSession(model_bytes)

def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch the runtime metadata JSON object from S3 and return it parsed.

    Args:
        runtimedata_s3_filename: Name of the JSON object, appended to the
            model-id key prefix hard-coded below.

    Returns:
        Whatever ``json.load`` produces for the object's body.
    """
    object_key = "$unique_model_id" + "/" + runtimedata_s3_filename
    s3_object = boto3.resource('s3').Object("$bucket_name", object_key)
    # The streaming body is file-like, so json.load can read it directly.
    response_body = s3_object.get()['Body']
    return json.load(response_body)


def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor archive from S3 and load its function.

    The archive is extracted into ``/tmp``. Every ``*.pkl`` file found in
    that directory is unpickled into a module-level global named after the
    file, and ``preprocessor.py`` is then executed in the module namespace;
    it is expected to define a callable named ``preprocessor``.

    NOTE(security): unpickling and ``exec`` run arbitrary code taken from
    the archive, so the bucket contents must be fully trusted.

    Args:
        preprocessor_s3_filename: Name of the zip object, appended to the
            model-id key prefix hard-coded below.

    Returns:
        The ``preprocessor`` object defined by the executed script.

    Raises:
        NameError: If ``preprocessor.py`` does not define ``preprocessor``.
    """
    import pickle
    from io import BytesIO
    from zipfile import ZipFile

    extract_dir = "/tmp"

    s3 = boto3.resource("s3")
    # Build the key from the argument; the default keeps the key that was
    # previously hard-coded.
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id/" + preprocessor_s3_filename)
    buffer = BytesIO(zip_obj.get()["Body"].read())
    with ZipFile(buffer) as archive:
        archive.extractall(extract_dir)

    # Snapshot the pickle files first, then load them, so the set of files
    # loaded does not depend on anything unpickling might write to disk.
    pickle_paths = [
        os.path.join(extract_dir, entry)
        for entry in os.listdir(extract_dir)
        if entry.endswith(".pkl")
    ]

    # The preprocessor script refers to these objects by bare name, so they
    # are published as module globals before the script is executed.
    for pickle_path in pickle_paths:
        object_name = os.path.basename(pickle_path).replace(".pkl", "")
        with open(pickle_path, "rb") as pickle_file:
            globals()[object_name] = pickle.load(pickle_file)

    # Executing the script in globals() is what defines `preprocessor`.
    with open(os.path.join(extract_dir, "preprocessor.py")) as script_file:
        exec(script_file.read(), globals())
    return preprocessor


def predict(event, model, preprocessor):

    # Load base64 encoded /. stored within "data" key of event dictionary
    # print(event["body"])
    body = event["body"]
    if isinstance(event["body"], six.string_types):
        body = json.loads(event["body"])
# only supporting wav extension as of now
    extension = body['extension']
    bodydata = body["data"]

    sample = base64.decodebytes(bytearray(bodydata, "utf-8"))

# Save video to local file, read into session, and preprocess image with preprocessor function
    with open("/tmp/videotopredict."+extension, "wb") as fh:
        fh.write(base64.b64decode(bodydata))

    input_data = preprocessor(f"/tmp/videotopredict.{extension}")

# Generate prediction using preprocessed input data
    print("The model expects input shape:", model.get_inputs()[0].shape)

    input_name = model.get_inputs()[0].name
    input_data = np.float32(input_data)

    res = model.run(None, {input_name: input_data})

    # extract predicted probability for all classes, extract predicted label

    prob = res[0]

    def predict_classes(x):
        proba = x
        if proba.shape[-1] > 1:
            return proba.argmax(axis=-1)
        else:
            return (proba > 0.5).astype("int32")

    prediction_index = predict_classes(prob)

    labels = $labels
    result = list(map(lambda x: labels[x], prediction_index))

    os.remove("/tmp/videotopredict."+extension)

    return result


# One-time setup executed when the module is imported: fetch the runtime
# metadata, the ONNX model and the preprocessor, in that order.
runtime_data = get_runtimedata("runtime_data.json")
runtime_model = runtime_data["runtime_model"]["name"]
model = get_model_onnx("runtime_model.onnx")
preprocessor = get_preprocessor("runtime_preprocessor.zip")


def handler(event, context):
    """Run ``predict`` on ``event`` and wrap the labels in a response dict.

    Args:
        event: Request payload, passed unchanged to ``predict``.
        context: Unused.

    Returns:
        A dict with ``statusCode`` 200, permissive CORS headers, and the
        JSON-encoded prediction result as ``body``.
    """
    labels_json = json.dumps(predict(event, model, preprocessor))
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Credentials": True,
        "Allow": "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods": "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers": "*",
    }
    return {
        "statusCode": 200,
        "headers": cors_headers,
        "body": labels_json,
    }
