#!/usr/bin/env python

import numpy as np
import cv2
from edgeimpulse.runner import ImpulseRunner
import time

class ImageImpulseRunner(ImpulseRunner):
    """Impulse runner that feeds frames from video device 0 to an image model.

    Intended to be used as a context manager: entering opens the video
    device, leaving releases it.  ``init()`` must be called to load the
    model's input dimensions before ``classifier()`` is iterated.
    """

    def __init__(self, model_path: str):
        """Create a runner for the model at *model_path*; no device is opened yet."""
        super(ImageImpulseRunner, self).__init__(model_path)
        self.closed = True
        self.labels = []
        # (width, height) expected by the model; populated by init().
        self.dim = (0, 0)
        # Unopened placeholder so release()/isOpened() are always callable.
        self.videoCapture = cv2.VideoCapture()
        self.isGrayscale = False

    def init(self):
        """Initialise the underlying runner and cache the image parameters.

        Reads ``image_input_width``, ``image_input_height``, ``labels`` and
        ``image_channel_count`` from ``model_info['model_parameters']``.

        Returns:
            The model info dictionary returned by the parent ``init()``.

        Raises:
            Exception: if the model reports a zero input width or height.
        """
        model_info = super(ImageImpulseRunner, self).init()
        params = model_info['model_parameters']

        width = params['image_input_width']
        height = params['image_input_height']

        if width == 0 or height == 0:
            raise Exception('Model file "' + self._model_path + '" is not suitable for image recognition')

        self.dim = (width, height)
        self.labels = params['labels']
        # A single input channel means frames must be converted to grayscale.
        self.isGrayscale = params['image_channel_count'] == 1
        return model_info

    def __enter__(self):
        """Open video device 0 and mark the runner as active.

        Raises:
            Exception: if the device cannot be opened.
        """
        message = 'Unable to open video device. Check your permission settings'
        try:
            self.videoCapture = cv2.VideoCapture(0)
            opened = self.videoCapture is not None and self.videoCapture.isOpened()
        except Exception as err:
            # Keep the original exception type/message for callers, but
            # preserve the underlying cause and let KeyboardInterrupt /
            # SystemExit propagate untouched.
            raise Exception(message) from err

        if not opened:
            # Do not keep a half-initialised capture handle around.
            if self.videoCapture is not None:
                self.videoCapture.release()
            raise Exception(message)

        self.closed = False
        return self

    def __exit__(self, type, value, traceback):
        """Release the video device and mark the runner as closed."""
        self.videoCapture.release()
        self.closed = True

    def classify(self, data):
        """Classify *data* by delegating to the parent runner's ``classify``."""
        return super(ImageImpulseRunner, self).classify(data)

    def _frame_to_features(self, img):
        """Resize a captured frame and pack its pixels into 0xRRGGBB integers.

        Args:
            img: frame as returned by ``cv2.VideoCapture.read()``.

        Returns:
            A ``(features, img)`` tuple where *features* is a flat list of
            Python ints (one per pixel) and *img* is the frame that was
            resized -- the grayscale conversion when the model is
            single-channel, otherwise the frame unchanged.
        """
        if self.isGrayscale:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        resized = cv2.resize(img, self.dim, interpolation=cv2.INTER_AREA)
        # Widen before shifting: the packed value needs up to 24 bits, which
        # would overflow the frame's narrow integer dtype.
        pixels = np.asarray(resized, dtype=np.uint32)

        if self.isGrayscale:
            # Replicate the single intensity into all three colour slots.
            flat = pixels.reshape(-1)
            packed = (flat << 16) + (flat << 8) + flat
        else:
            # Consecutive triples are read in B, G, R order.
            triples = pixels.reshape(-1, 3)
            blue = triples[:, 0]
            green = triples[:, 1]
            red = triples[:, 2]
            packed = (red << 16) + (green << 8) + blue

        return packed.tolist(), img

    def classifier(self):
        """Yield ``(result, image)`` pairs for frames read from the device.

        Iteration continues while the runner is not closed and the capture
        reports itself open.  Frames that fail to read are skipped.  The
        yielded image is the grayscale-converted frame for single-channel
        models, otherwise the frame as captured.
        """
        while not self.closed and self.videoCapture.isOpened():
            ret, img = self.videoCapture.read()
            if not ret:
                continue

            features, img = self._frame_to_features(img)
            res = self.classify(features)
            yield res, img
