# Video classification: loads an ONNX model and preprocessor from S3 and
# serves predictions from an AWS Lambda handler.
import base64
import json
import os

import boto3
import numpy as np
import onnxruntime as rt
import six

# Kept available in globals for the exec'd preprocessor code, which may
# reference these without importing them itself.
import warnings
import pandas as pd

def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    # Download the serialized ONNX model from S3 and build an inference session.
    s3 = boto3.resource('s3')
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    model = rt.InferenceSession(obj.get()['Body'].read())
    return model
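
# A minimal sketch (an assumption, not part of the generated runtime) of how
# to inspect the loaded session via onnxruntime's input/output metadata:
#
#   session = get_model_onnx()
#   for inp in session.get_inputs():
#       print(inp.name, inp.shape, inp.type)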

def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    s3 = boto3.resource('s3')
    obj = s3.Object("$bucket_name", "$unique_model_id" +
                    "/"+runtimedata_s3_filename)
    runtime_data = json.load(obj.get()['Body'])

    return runtime_data
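
# Only the model name under "runtime_model" is read from this JSON below; an
# assumed minimal shape would be:
#
#   {"runtime_model": {"name": "..."}}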

def _get_pyspark_modules():
    # Map pyspark model/transformer class names to the modules that define
    # them, so classes can later be re-instantiated from string names.
    import importlib
    import re

    pyspark_modules = ['ml', 'ml.feature', 'ml.classification', 'ml.clustering', 'ml.regression']

    models_modules_dict = {}

    for i in pyspark_modules:
        module = importlib.import_module('pyspark.' + i)
        # Public CamelCase callables are assumed to be model/transformer classes.
        models_list = [j for j in dir(module)
                       if re.match('^[A-Z]', j) and callable(getattr(module, j))]

        for k in models_list:
            models_modules_dict[k] = 'pyspark.' + i

    return models_modules_dict
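
# Illustrative output (exact contents depend on the installed pyspark version):
#
#   {"StandardScaler": "pyspark.ml.feature",
#    "LogisticRegression": "pyspark.ml.classification",
#    "PipelineModel": "pyspark.ml", ...}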

def pyspark_model_from_string(model_type):
    import importlib

    models_modules_dict = _get_pyspark_modules()
    module = models_modules_dict[model_type]
    model_class = getattr(importlib.import_module(module), model_type)
    return model_class
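
# Usage example:
#
#   PipelineModel = pyspark_model_from_string("PipelineModel")
#   loaded = PipelineModel.load("/tmp/saved_pipeline")  # hypothetical path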

def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    import os
    import pickle
    import tempfile
    from io import BytesIO
    from pathlib import Path
    from zipfile import ZipFile

    # Create a model-specific temporary folder; other zip files in the shared
    # temp dir could otherwise interfere with the import process.
    temp_dir = tempfile.gettempdir()
    temp_dir = os.path.join(temp_dir, "$unique_model_id")
    os.makedirs(temp_dir, exist_ok=True)

    s3 = boto3.resource("s3")
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id/" + preprocessor_s3_filename)
    buffer = BytesIO(zip_obj.get()["Body"].read())
    z = ZipFile(buffer)
    # Extract all the contents of the zip file into the temp directory
    z.extractall(temp_dir)
    
    # Load every pickled object shipped with the preprocessor into the module
    # globals, so the preprocessor function can reference it by name.
    pickle_file_list = []
    zip_file_list = []
    for file in os.listdir(temp_dir):
        if file.endswith(".pkl"):
            pickle_file_list.append(os.path.join(temp_dir, file))
        if file.endswith(".zip"):
            zip_file_list.append(os.path.join(temp_dir, file))

    for i in pickle_file_list:
        objectname = str(os.path.basename(i)).replace(".pkl", "")
        with open(i, "rb") as f:
            globals()[objectname] = pickle.load(f)
    
    # Zip files are only produced for pyspark preprocessors; a Spark session
    # is needed to instantiate those model objects.
    if len(zip_file_list):
        from pyspark.sql import SparkSession
    
        spark = SparkSession \
            .builder \
            .appName('Pyspark Model') \
            .getOrCreate()
    
        for i in zip_file_list:
            objectnames = str(os.path.basename(i)).replace(".zip", "").split("__")
            dir_path = i.replace(".zip", "")
            Path(dir_path).mkdir(parents=True, exist_ok=True)
              
            # Extract the zipped pyspark model into its own directory
            with ZipFile(i, 'r') as zipObj:
                zipObj.extractall(dir_path)

            preprocessor_type = objectnames[0].split("_")[0]
            objectname = objectnames[1]
            preprocessor_class = pyspark_model_from_string(preprocessor_type)
            # PipelineModel's constructor requires a stages argument.
            if preprocessor_type == "PipelineModel":
                preprocessor_model = preprocessor_class(stages=None)
            else:
                preprocessor_model = preprocessor_class()

            preprocessor_model = preprocessor_model.load(dir_path)
            globals()[objectname] = preprocessor_model

    # Finally, bring the preprocessor function into this session by executing
    # preprocessor.py, which defines `preprocessor` in globals().
    exec(open(os.path.join(temp_dir, 'preprocessor.py')).read(), globals())
    return preprocessor
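
# The archive is expected to ship a preprocessor.py that defines a
# `preprocessor(file_path)` callable returning a model-ready array. A
# hypothetical minimal example, assuming a model that consumes the first
# few frames resized to a fixed resolution:
#
#   import cv2
#   import numpy as np
#
#   def preprocessor(video_path, num_frames=16, size=(224, 224)):
#       cap = cv2.VideoCapture(video_path)
#       frames = []
#       while len(frames) < num_frames:
#           ok, frame = cap.read()
#           if not ok:
#               break
#           frames.append(cv2.resize(frame, size) / 255.0)
#       cap.release()
#       # shape: (1, num_frames, height, width, channels)
#       return np.expand_dims(np.stack(frames), axis=0)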


def predict(event, model, preprocessor):

    # The request body carries a base64-encoded video under the "data" key
    # and its file extension under "extension".
    body = event["body"]
    if isinstance(event["body"], six.string_types):
        body = json.loads(event["body"])

    extension = body['extension']
    bodydata = body["data"]

    # Save the video to a local file, then preprocess it with the
    # user-supplied preprocessor function.
    with open("/tmp/videotopredict." + extension, "wb") as fh:
        fh.write(base64.b64decode(bodydata))

    input_data = preprocessor(f"/tmp/videotopredict.{extension}")

    # Generate a prediction from the preprocessed input data.
    print("The model expects input shape:", model.get_inputs()[0].shape)

    input_name = model.get_inputs()[0].name
    input_data = np.float32(input_data)

    res = model.run(None, {input_name: input_data})

    # Extract the predicted probabilities for all classes, then derive the
    # predicted label index.
    prob = res[0]

    def predict_classes(x):
        proba = x
        # Multi-class output (N, C): take the argmax over classes;
        # single-column binary output: threshold the probability at 0.5.
        if proba.shape[-1] > 1:
            return proba.argmax(axis=-1)
        else:
            return (proba > 0.5).astype("int32")

    prediction_index = predict_classes(prob)

    labels = $labels
    result = list(map(lambda x: labels[x], prediction_index))

    os.remove("/tmp/videotopredict."+extension)

    # pyspark returns np.longlong, which is not JSON serializable; cast to int.
    if len(result) and isinstance(result[0], np.longlong):
        result = [int(x) for x in result]

    return result


# Load the runtime configuration, model, and preprocessor once at module
# import so warm Lambda invocations reuse them.
runtime_data = get_runtimedata(runtimedata_s3_filename="runtime_data.json")

runtime_model = runtime_data["runtime_model"]["name"]

model = get_model_onnx(runtimemodel_s3_filename='runtime_model.onnx')

# Load preprocessor
preprocessor = get_preprocessor(
    preprocessor_s3_filename="runtime_preprocessor.zip")


def handler(event, context):
    result = predict(event, model, preprocessor)
    return {"statusCode": 200,
            "headers": {
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Credentials": True,
                "Allow": "GET, OPTIONS, POST",
                "Access-Control-Allow-Methods": "GET, OPTIONS, POST",
                "Access-Control-Allow-Headers": "*"
            },
            "body": json.dumps(result)}
