#Image Classification Prediction Runtime Code

import boto3
import cv2
import os
import numpy as np
import json
import onnxruntime as rt
import base64
import imghdr
import six
from functools import partial
import os.path
from os import path
import pandas as pd


def get_model_onnx(runtimemodel_s3_filename="runtime_model.onnx"):
    """Download the ONNX model from S3 and open an inference session on it.

    The bucket name and the model-id key prefix are template placeholders
    in this file (presumably substituted before deployment).

    Args:
        runtimemodel_s3_filename: Name of the model object stored under the
            model-id prefix of the bucket.

    Returns:
        The ``onnxruntime.InferenceSession`` created from the downloaded
        model bytes.
    """
    s3 = boto3.resource('s3')
    # Build the key from the argument; it used to be hard-coded, which made
    # the parameter a silent no-op for any non-default value.
    obj = s3.Object("$bucket_name",
                    "$unique_model_id" + "/" + runtimemodel_s3_filename)
    model_bytes = obj.get()['Body'].read()
    return rt.InferenceSession(model_bytes)

def _get_pyspark_modules():
    """Map capitalized callable names found in pyspark.ml modules to a module path.

    Scans a fixed list of ``pyspark`` modules and records every attribute
    that is callable and whose name starts with an ASCII capital letter.
    When the same name appears in several modules, the module listed last
    wins.

    Returns:
        dict mapping attribute name (e.g. a model or transformer class
        name) to the dotted module path it was found in.
    """
    import importlib
    import re

    pyspark_modules = ['ml', 'ml.feature', 'ml.classification', 'ml.clustering', 'ml.regression']
    starts_with_capital = re.compile('^[A-Z]')

    models_modules_dict = {}

    for suffix in pyspark_modules:
        module_path = 'pyspark.' + suffix
        # Resolve the module once per entry with importlib rather than
        # calling eval() for every attribute.
        module = importlib.import_module(module_path)

        for attr_name in dir(module):
            if callable(getattr(module, attr_name)) and starts_with_capital.match(attr_name):
                models_modules_dict[attr_name] = module_path

    return models_modules_dict

def pyspark_model_from_string(model_type):
    """Return the pyspark class whose name is ``model_type``.

    The owning module is looked up in the mapping produced by
    ``_get_pyspark_modules``; an unknown name raises ``KeyError``.
    """
    from importlib import import_module

    owning_module_path = _get_pyspark_modules()[model_type]
    owning_module = import_module(owning_module_path)
    return getattr(owning_module, model_type)

def get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip"):
    """Download the preprocessor bundle from S3 and return its ``preprocessor`` function.

    Steps:
      1. Download the zip archive and extract it into a per-model folder
         inside the system temp directory.
      2. Unpickle every extracted ``*.pkl`` file into a module global named
         after the file (without the ``.pkl`` part).
      3. For every extracted ``*.zip`` file, extract it, load it as a
         pyspark object and store it in a module global. The file name is
         expected to look like ``<ClassName>_...__<objectname>.zip``.
      4. Execute the extracted ``preprocessor.py`` in the module globals,
         which is expected to define ``preprocessor``.

    SECURITY NOTE: this unpickles files and exec()s Python source taken
    from the downloaded archive. Both run arbitrary code, so the bucket
    contents must be fully trusted.

    Args:
        preprocessor_s3_filename: Name of the zip object stored under the
            model-id prefix of the bucket.

    Returns:
        The ``preprocessor`` object defined by the executed
        ``preprocessor.py``.
    """
    import os
    import pickle
    import tempfile
    from io import BytesIO
    from pathlib import Path
    from zipfile import ZipFile

    # Use a per-model subfolder: other zip files may already live in the
    # shared temp directory and would otherwise be picked up below.
    temp_dir = os.path.join(tempfile.gettempdir(), "$unique_model_id")
    os.makedirs(temp_dir, exist_ok=True)

    s3 = boto3.resource("s3")
    # Build the key from the argument; it used to be hard-coded, which made
    # the parameter a silent no-op for any non-default value.
    zip_obj = s3.Object(bucket_name="$bucket_name",
                        key="$unique_model_id" + "/" + preprocessor_s3_filename)
    buffer = BytesIO(zip_obj.get()["Body"].read())
    with ZipFile(buffer) as archive:
        # Extract all the contents of the zip file into the temp directory
        archive.extractall(temp_dir)

    # Collect the pickled objects and the (pyspark-only) nested zip files.
    pickle_file_list = []
    zip_file_list = []
    for file in os.listdir(temp_dir):
        full_path = os.path.join(temp_dir, file)
        if file.endswith(".pkl"):
            pickle_file_list.append(full_path)
        elif file.endswith(".zip"):
            zip_file_list.append(full_path)

    for pickle_path in pickle_file_list:
        objectname = str(os.path.basename(pickle_path)).replace(".pkl", "")
        # Close the file deterministically instead of leaking the handle.
        with open(str(pickle_path), "rb") as pickle_file:
            globals()[objectname] = pickle.load(pickle_file)

    # Need spark session and context to instantiate model object
    # zip_file_list is only used by pyspark
    if zip_file_list:
        from pyspark.sql import SparkSession

        spark = SparkSession \
            .builder \
            .appName('Pyspark Model') \
            .getOrCreate()

        for zip_path in zip_file_list:
            objectnames = str(os.path.basename(zip_path)).replace(".zip", "").split("__")
            dir_path = zip_path.replace(".zip", "")
            Path(dir_path).mkdir(parents=True, exist_ok=True)

            # Extract the nested archive into its own directory
            with ZipFile(zip_path, 'r') as zipObj:
                zipObj.extractall(dir_path)

            preprocessor_type = objectnames[0].split("_")[0]
            objectname = objectnames[1]
            preprocessor_class = pyspark_model_from_string(preprocessor_type)
            if preprocessor_type == "PipelineModel":
                print(preprocessor_class)
                preprocessor_model = preprocessor_class(stages=None)
            else:
                preprocessor_model = preprocessor_class()

            preprocessor_model = preprocessor_model.load(dir_path)
            globals()[objectname] = preprocessor_model

    # Import the preprocessor function into the module namespace by running
    # preprocessor.py with the module globals as its namespace.
    with open(os.path.join(temp_dir, 'preprocessor.py')) as source_file:
        preprocessor_source = source_file.read()
    exec(preprocessor_source, globals())
    return preprocessor

def get_runtimedata(runtimedata_s3_filename="runtime_data.json"):
    """Fetch the runtime metadata JSON object from S3 and return it parsed.

    Args:
        runtimedata_s3_filename: Name of the JSON object stored under the
            model-id prefix of the bucket.

    Returns:
        The decoded JSON document.
    """
    object_key = "$unique_model_id" + "/" + runtimedata_s3_filename
    s3_object = boto3.resource('s3').Object("$bucket_name", object_key)
    body_stream = s3_object.get()['Body']
    return json.load(body_stream)


# Module-level initialisation: everything below runs once when the module
# is imported, so handler() can reuse the loaded model and preprocessor.
runtime_data = get_runtimedata(runtimedata_s3_filename="runtime_data.json")

# Metadata read from the runtime JSON document.
preprocessor_type = runtime_data["runtime_preprocessor"]
runtime_model = runtime_data["runtime_model"]["name"]

# Load the ONNX inference session.
model = get_model_onnx(runtimemodel_s3_filename='runtime_model.onnx')

# Load the preprocessor function.
preprocessor = get_preprocessor(preprocessor_s3_filename="runtime_preprocessor.zip")


def predict(event,model,preprocessor):

  labels=$labels
  
# Load base64 encoded image stored within "data" key of event dictionary
  print(event["body"])
  body = event["body"]
  if isinstance(event["body"], six.string_types):
        body = json.loads(event["body"])
        
  bodydata=body["data"]

# Extract image file extension (e.g.-jpg, png, etc.)
  sample = base64.decodebytes(bytearray(bodydata, "utf-8"))

  for tf in imghdr.tests:
    image_file_type = tf(sample, None)
    if image_file_type:
      break;
  image_file_type=image_file_type
 
  if(image_file_type==None):
    print("This file is not an image, please submit an image base64 encoded image file.")

# Save image to local file, read into session, and preprocess image with preprocessor function

  with open("/tmp/imagetopredict."+image_file_type, "wb") as fh:
    fh.write(base64.b64decode(bodydata))

  input_data = preprocessor("/tmp/imagetopredict."+image_file_type)
  

# Generate prediction using preprocessed input data
  print("The model expects input shape:", model.get_inputs()[0].shape)
  
  input_name = model.get_inputs()[0].name

  res = model.run(None, {input_name: input_data})
 
  #extract predicted probability for all classes, extract predicted label
  
  prob = res[0]

  def predict_classes(x): # adjusted from keras github code
    if len(x.shape)==2:
        index = x.argmax(axis=-1)
        return list(map(lambda x: labels[x], index))
    else:
        return list(x)
      
  result=predict_classes(prob)

  os.remove("/tmp/imagetopredict."+image_file_type)

  # pyspark returns np.longlong which is not json serializable.
  if len(result) and type(result[0]) == np.longlong:
      result = [int(x) for x in result]

  return result

def handler(event, context):
    """Entry point: run a prediction for ``event`` and wrap it in an HTTP-style response.

    ``context`` is accepted but not used. The response always carries
    status code 200, a fixed set of CORS headers, and the JSON-encoded
    prediction as its body.
    """
    prediction = predict(event, model, preprocessor)

    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Credentials": True,
        "Allow": "GET, OPTIONS, POST",
        "Access-Control-Allow-Methods": "GET, OPTIONS, POST",
        "Access-Control-Allow-Headers": "*",
    }

    response = {
        "statusCode": 200,
        "headers": cors_headers,
        "body": json.dumps(prediction),
    }
    return response
