#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import os
import argparse
import cv2 as cv
import numpy as np
import psd_tools
from psd_tools import PSDImage
from itertools import repeat
from concurrent.futures import ThreadPoolExecutor as Pool

def extract_layer(layer):
    """Render a single layer and save it as a PNG in the current directory.

    The file name is built from the layer name (spaces replaced by
    underscores) followed by the first and second entry of ``layer.offset``.

    Args:
        layer: object exposing ``name``, ``offset`` and ``compose()``;
            presumably a psd_tools layer -- confirm against callers.
    """
    sanitized_name = layer.name.replace(' ', '_')
    offset_x, offset_y = layer.offset[0], layer.offset[1]
    target_path = './{}_{}_{}.png'.format(sanitized_name, offset_x, offset_y)

    rendered = layer.compose()
    rendered.save(target_path)

def layer_to_opencv(layer, bbox=None):
    """Render a layer and return it as a BGRA float32 array.

    Args:
        layer: object whose ``compose(bbox=...)`` result can be converted
            to RGBA; presumably a psd_tools layer -- confirm against callers.
        bbox: bounding box forwarded unchanged to ``layer.compose``.

    Returns:
        float32 array with channels in BGRA order, scaled by 1/255.
    """
    rendered_rgba = layer.compose(bbox=bbox).convert('RGBA')

    # work on an independent copy of the rendered pixel data
    rgba_pixels = np.array(rendered_rgba).copy()

    # swap the channel order from RGBA to BGRA, then scale to float
    bgra_pixels = cv.cvtColor(rgba_pixels, cv.COLOR_RGBA2BGRA)
    return np.float32(bgra_pixels) / 255.0

def combine(background_4C, foreground_4C, mask_1C=None, debug_depth=None):
    """Blend ``foreground_4C`` over ``background_4C`` using an alpha mask.

    Both images are split into their leading channels (colour) and their
    last channel (alpha). The colour channels are mixed as
    ``mask * foreground + (1 - mask) * background``. The alpha channel of
    the result is the element-wise maximum of the mask and the background
    alpha.

    Args:
        background_4C: 4-channel background image.
        foreground_4C: 4-channel foreground image of the same height/width.
        mask_1C: optional single-channel blend mask; when ``None`` the alpha
            channel of the foreground is used instead.
        debug_depth: indentation level for debug output; ``None`` disables
            the output.

    Returns:
        A new 4-channel image; neither input is written to.

    NOTE(review): the ``1.0 - mask`` term assumes float images in the range
    0..1 -- confirm against callers.
    """
    # split background into colour and alpha
    background_3C = background_4C[:, :, :-1]
    background_alpha_1C = background_4C[:, :, -1]

    # split foreground into colour and alpha
    foreground_3C = foreground_4C[:, :, :-1]
    foreground_alpha_1C = foreground_4C[:, :, -1]

    # select alpha mask and replicate it to three channels so that it can be
    # multiplied with the colour channels
    mask_alpha_1C = mask_1C if mask_1C is not None else foreground_alpha_1C
    mask_alpha_3C = cv.cvtColor(mask_alpha_1C, cv.COLOR_GRAY2BGR)

    if debug_depth is not None:
        indent = '\t' * debug_depth
        print('\n' + indent + 'combining background and foreground using mask')
        print(indent + '\tbackground_4C:'.ljust(24), background_4C.shape, background_4C.dtype)
        # report the dtype of the foreground itself, not of the background
        print(indent + '\tforeground_4C:'.ljust(24), foreground_4C.shape, foreground_4C.dtype)
        print(indent + '\tmask_alpha_3C:'.ljust(24), mask_alpha_3C.shape, mask_alpha_3C.dtype)

    # Multiply the foreground with the alpha mask
    foreground_3C = cv.multiply(mask_alpha_3C, foreground_3C)

    # Multiply the background with ( 1 - alpha )
    background_3C = cv.multiply(1.0 - mask_alpha_3C, background_3C)

    # Add the masked foreground and background.
    merged_3C = cv.add(foreground_3C, background_3C)

    # add an alpha channel to the image and assign the correct value
    merged_4C = cv.cvtColor(merged_3C, cv.COLOR_BGR2BGRA)
    merged_4C[:, :, -1] = np.max((mask_alpha_1C, background_alpha_1C), axis=0)

    if debug_depth is not None:
        print('\t' * debug_depth + '\t-> done')

    return merged_4C

def add_layer(background_4C, foreground_4C, bbox, debug_depth=None):
    """Blend a layer into the canvas region described by ``bbox``.

    The region of ``background_4C`` covered by ``bbox`` is combined with
    ``foreground_4C`` and written back, so the canvas is modified in place.

    Args:
        background_4C: 4-channel canvas, at least as large as the bbox.
        foreground_4C: 4-channel image to blend into the region.
        bbox: ``(left, top, right, bottom)`` in canvas pixel coordinates.
        debug_depth: indentation level for debug output; ``None`` disables it.

    Returns:
        The (modified) ``background_4C``.

    Raises:
        AssertionError: if the canvas is smaller than the bbox or does not
            have four channels.
    """
    # bbox = (left, top, right, bottom)
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    canvas_shape = background_4C.shape

    assert canvas_shape[0] >= bottom, 'background height is to small: {} >= {}'.format(canvas_shape[0], bottom)
    assert canvas_shape[1] >= right, 'background width is to small: {} >= {}'.format(canvas_shape[1], right)
    assert canvas_shape[2] == 4, 'background doesnt have 4 channels: {} >= 4'.format(canvas_shape[2])

    # nested helpers print one indentation level deeper
    nested_depth = None if debug_depth is None else debug_depth + 1

    region = (slice(top, bottom), slice(left, right))
    background_4C[region] = combine(background_4C[region], foreground_4C, debug_depth=nested_depth)
    return background_4C

def add_screen(background_4C, screenshot_3C, mask_1C, bbox, debug_depth=None):
    """Paint a screenshot into the canvas region described by ``bbox``.

    The screenshot is resized to the size of ``mask_1C``, extended with an
    alpha channel and combined with the canvas region using ``mask_1C`` as
    the blend mask. The canvas is modified in place.

    Args:
        background_4C: 4-channel canvas, at least as large as the bbox.
        screenshot_3C: 3-channel screenshot image.
        mask_1C: single-channel mask selecting where the screenshot shows.
        bbox: ``(left, top, right, bottom)`` in canvas pixel coordinates.
        debug_depth: indentation level for debug output; ``None`` disables it.

    Returns:
        The (modified) ``background_4C``.

    Raises:
        AssertionError: if the canvas is smaller than the bbox or does not
            have four channels.
    """
    # bbox = (left, top, right, bottom)
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    canvas_shape = background_4C.shape

    assert canvas_shape[0] >= bottom, 'background height is to small: {} >= {}'.format(canvas_shape[0], bottom)
    assert canvas_shape[1] >= right, 'background width is to small: {} >= {}'.format(canvas_shape[1], right)
    assert canvas_shape[2] == 4, 'background doesnt have 4 channels: {} >= 4'.format(canvas_shape[2])

    verbose = debug_depth is not None
    nested_depth = debug_depth + 1 if verbose else None

    # cv.resize expects the target size as (width, height)
    mask_height, mask_width = mask_1C.shape[0], mask_1C.shape[1]
    screenshot_3C = cv.resize(screenshot_3C, (mask_width, mask_height))

    # get region of screen
    region = (slice(top, bottom), slice(left, right))
    background_region_4C = background_4C[region]

    if verbose:
        indent = '\t' * debug_depth
        print('\n' + indent + 'adding screen to existing canvas')
        print(indent + '\tbackground_4C:'.ljust(24), background_4C.shape, background_4C.dtype)
        print(indent + '\tbackground_region_4C:'.ljust(24), background_region_4C.shape, background_region_4C.dtype)
        print(indent + '\tscreenshot_3C:'.ljust(24), screenshot_3C.shape, screenshot_3C.dtype)
        print(indent + '\tmask_1C:'.ljust(24), mask_1C.shape, mask_1C.dtype)

    screenshot_4C = cv.cvtColor(screenshot_3C, cv.COLOR_BGR2BGRA)
    background_4C[region] = combine(background_region_4C, screenshot_4C, mask_1C, debug_depth=nested_depth)

    if verbose:
        print('\t' * debug_depth + '\t-> done')

    return background_4C

def trim_alpha(input_4C, debug_depth=None):
    """Crop an image to the bounding box of its non-transparent pixels.

    A pixel counts as non-transparent when its value in the last channel
    (alpha) is non-zero.

    Args:
        input_4C: image array whose last channel is the alpha channel.
        debug_depth: indentation level for debug output; ``None`` disables it.

    Returns:
        The slice of ``input_4C`` spanning every row and column that holds at
        least one non-transparent pixel. An image without any such pixel is
        returned unchanged.
    """
    alpha_1C = input_4C[:, :, -1]

    # indices of the columns / rows containing at least one visible pixel
    x_occurences = np.where(np.any(alpha_1C, axis=0))[0]
    y_occurences = np.where(np.any(alpha_1C, axis=1))[0]

    if x_occurences.size == 0 or y_occurences.size == 0:
        # fully transparent image: there is nothing to crop to
        if debug_depth is not None:
            print('\n' + '\t' * debug_depth + 'trimming alpha on all sides')
            print('\t' * debug_depth + '\tinput_4C:'.ljust(24), input_4C.shape, input_4C.dtype)
            print('\t' * debug_depth + '\t-> image is fully transparent, nothing trimmed')
        return input_4C

    # slice ends are exclusive, so step one past the last visible index;
    # otherwise the last visible row and column would be cut off
    mask_left = x_occurences.min()
    mask_right = x_occurences.max() + 1
    mask_top = y_occurences.min()
    mask_bottom = y_occurences.max() + 1

    if debug_depth is not None:
        print('\n' + '\t' * debug_depth + 'trimming alpha on all sides')
        print('\t' * debug_depth + '\tinput_4C:'.ljust(24), input_4C.shape, input_4C.dtype)
        print('\t' * debug_depth + '\tnew bounding box:'.ljust(24), mask_left, mask_right, mask_top, mask_bottom)

    return input_4C[mask_top:mask_bottom, mask_left:mask_right]

def resize_and_pad(input, output_size=None, debug_depth=None):
    """Scale an image to the requested size while keeping its aspect ratio.

    Args:
        input: image array of shape ``(height, width, channels)``. When both
            target dimensions are given the result is pasted into a
            4-channel canvas, so the image needs four channels in that case.
        output_size: ``(width, height)`` tuple; either entry may be ``None``.
            Any value that is not a tuple leaves the image untouched.
        debug_depth: accepted for parity with the other helpers; unused.

    Returns:
        * the input itself when no target dimension is given,
        * the uniformly scaled image when only one dimension is given,
        * otherwise a float32 canvas of exactly the requested size with the
          scaled image centred on it and zero (transparent) padding around.
    """
    if not isinstance(output_size, tuple):
        return input

    (dest_width, dest_height) = output_size
    if dest_width is None and dest_height is None:
        # width and height are both not given
        return input

    (orig_height, orig_width) = input.shape[:2]

    if dest_height is None:
        # no height given -> scale it to given width keeping aspect ratio
        scale = dest_width / orig_width
        return cv.resize(input, dsize=(0, 0), fx=scale, fy=scale)

    if dest_width is None:
        # no width given -> scale it to given height keeping aspect ratio
        scale = dest_height / orig_height
        return cv.resize(input, dsize=(0, 0), fx=scale, fy=scale)

    # height and width are defined: the smaller ratio is the constraining
    # one, so the scaled image fits into the canvas in both directions
    scale = min(dest_width / orig_width, dest_height / orig_height)
    resized = cv.resize(input, dsize=(0, 0), fx=scale, fy=scale)

    # never paste more than the canvas can hold, even if rounding inside
    # the resize made the result a pixel larger than expected
    new_height = min(resized.shape[0], dest_height)
    new_width = min(resized.shape[1], dest_width)

    # center the image and pad it with transparent pixels
    start_y = (dest_height - new_height) // 2
    start_x = (dest_width - new_width) // 2

    output = np.zeros((dest_height, dest_width, 4), dtype=np.float32)
    output[start_y:start_y + new_height, start_x:start_x + new_width] = resized[:new_height, :new_width]
    return output

def make_opaque(input_4C, color, debug_depth=None):
    """Place an image on top of a solid, fully opaque background colour.

    Args:
        input_4C: 4-channel image to flatten.
        color: sequence of three colour components; they are reversed when
            building the background pixel (RGB input -> BGR image order).
        debug_depth: indentation level for debug output; ``None`` disables it.

    Returns:
        The result of combining the solid background with ``input_4C``.

    Raises:
        AssertionError: if ``color`` does not have exactly three entries.
    """
    assert len(color) == 3, 'this is not a color: {}'.format(color)

    # reverse the colour order and append a fully opaque alpha value
    first, second, third = color[0], color[1], color[2]
    pixel_value = (third, second, first, 1)
    background_4C = np.full(input_4C.shape, pixel_value, dtype=np.float32)

    nested_depth = None
    if debug_depth is not None:
        indent = '\t' * debug_depth
        print('\n' + indent + 'making image opaque')
        print(indent + '\tinput_4C:'.ljust(24), input_4C.shape, input_4C.dtype)
        print(indent + '\tcolor (bgra):'.ljust(24), pixel_value)
        nested_depth = debug_depth + 1

    return combine(background_4C, input_4C, debug_depth=nested_depth)

class MockingBird:
    """Embeds screenshots into the device frames of a PSD document.

    On construction the document is opened and its layers are sorted into
    frames (pixel layers) and placeholders (smart object layers). Every frame
    is rendered once and the alpha channel of every placeholder is kept as a
    mask, so that many mockups can be produced from the same document.
    """

    # lower-case file extensions that are picked up as screenshots
    _SCREENSHOT_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, document_filepath):
        """Open the PSD file and pre-render its frames and placeholder masks.

        Args:
            document_filepath: path of the PSD document to open.

        Raises:
            AssertionError: if the document has no frame layer or the number
                of frame layers differs from the number of placeholders.
        """

        ################### OPEN DOCUMENT AND EXTRACT LAYERS ###################

        self.document = PSDImage.open(document_filepath)
        print('Content Offset'.rjust(16), self.document.offset)
        print('Content Size'.rjust(16), self.document.size)

        self.frame_layers = []
        self.placeholder_layers = []

        for layer in self.document.descendants():
            lowered_name = layer.name.lower()
            # shadows, status bars and solid colour fills are not composed
            if 'shadow' in lowered_name or 'status bar' in lowered_name or type(layer) == psd_tools.api.adjustments.SolidColorFill:
                continue

            # only the part of the name in front of the first '<' is shown
            layer_name = layer.name.partition('<')[0]

            if type(layer) == psd_tools.api.layers.PixelLayer:
                print('Frame'.rjust(16), layer_name.ljust(32), layer.size, layer.mask)
                self.frame_layers.append(layer)

                # TODO: maybe this is a lot faster?
                # has_pixels()
                # Returns True if the layer has associated pixels. When this is True, topil method returns PIL.Image.

            elif type(layer) == psd_tools.api.layers.SmartObjectLayer:
                print('Placeholder'.rjust(16), layer_name.ljust(32), layer.size, layer.mask)
                self.placeholder_layers.append(layer)

        self.num_placeholders = len(self.placeholder_layers)
        num_frames = len(self.frame_layers)

        assert num_frames >= 1
        assert num_frames == self.num_placeholders

        ############# RENDER ALL FRAMES AND EXTRACT MASKS FROM PSD #############

        print('rendering frames...', end='', flush=True)
        self.frame_images = []
        for frame_layer in self.frame_layers:
            frame_bbox = self._clipped_bbox(frame_layer)
            frame_image = layer_to_opencv(frame_layer, bbox=frame_bbox)
            self.frame_images.append((frame_image, frame_bbox))
        print(' -> done :-)')

        print('extracting masks...', end='', flush=True)
        self.placeholder_masks = []
        for placeholder_layer in self.placeholder_layers:
            mask_bbox = self._clipped_bbox(placeholder_layer)
            # only the alpha channel of the rendered placeholder is needed
            placeholder_mask = layer_to_opencv(placeholder_layer, bbox=mask_bbox)[:, :, -1]
            self.placeholder_masks.append((placeholder_mask, mask_bbox))
        print(' -> done :-)')

    def _clipped_bbox(self, layer):
        """Return the bbox of a layer (or of its mask) clipped to the viewbox.

        The bbox of the layer mask is used when the layer has one, otherwise
        the bbox of the layer itself.
        """
        bbox = layer.mask.bbox if layer.has_mask() else layer.bbox
        viewbox = self.document.viewbox

        # account for offsets on viewbox of document
        # bbox = (left, top, right, bottom)
        return (
            max(bbox[0], viewbox[0]),
            max(bbox[1], viewbox[1]),
            min(bbox[2], viewbox[2]),
            min(bbox[3], viewbox[3]),
        )

    @classmethod
    def _list_screenshots(cls, directory):
        """Return the screenshot paths of a directory in sorted order.

        ``os.listdir`` gives no ordering guarantee; sorting keeps the pairing
        of files across several directories deterministic.
        """
        return [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
            if os.path.splitext(filename)[1].lower() in cls._SCREENSHOT_EXTENSIONS
        ]

    @staticmethod
    def _output_filename(screenshot_filepath):
        """Return the file name of a screenshot with its extension set to .png."""
        stem = os.path.splitext(os.path.basename(screenshot_filepath))[0]
        return stem + '.png'

    def embed_screenshots(self, screenshots_3C, debug=False):
        """Compose frames and screenshots onto a canvas of document size.

        For every screenshot the frame with the same index is blended onto
        the canvas first, then the screenshot is painted through the
        placeholder mask with the same index.

        Args:
            screenshots_3C: sequence of 3-channel screenshot images.
            debug: print details about every step when true.

        Returns:
            The composed 4-channel float32 canvas.
        """
        canvas_4C = np.zeros((self.document.height, self.document.width, 4), dtype=np.float32)
        debug_depth = 1 if debug else None

        if debug:
            print('\tdocument:'.ljust(24), self.document.width, self.document.height, self.document.bbox, self.document.viewbox)
            print('\tcanvas_4C:'.ljust(24), canvas_4C.shape, canvas_4C.dtype)

        for i, screenshot_3C in enumerate(screenshots_3C):

            # add background frame
            (frame_4C, frame_bbox) = self.frame_images[i]

            if debug:
                print('\tframe_4C:'.ljust(24), frame_4C.shape, frame_4C.dtype)
                print('\tframe_bbox:'.ljust(24), frame_bbox)

            canvas_4C = add_layer(canvas_4C, frame_4C, frame_bbox, debug_depth=debug_depth)

            # replace placeholder with screenshot
            (mask_1C, mask_bbox) = self.placeholder_masks[i]

            if debug:
                print('\tmask_1C:'.ljust(24), mask_1C.shape, mask_1C.dtype)
                print('\tmask_bbox:'.ljust(24), mask_bbox)

            canvas_4C = add_screen(canvas_4C, screenshot_3C, mask_1C, mask_bbox, debug_depth=debug_depth)

        return canvas_4C

    def create_mockup(self, screenshot_filepaths, output_directory, output_size=None, should_trim_alpha=False, background_color=None, debug=False):
        """Render one mockup from a group of screenshots and write it as PNG.

        Args:
            screenshot_filepaths: one screenshot path per device frame,
                ordered like the frames of the document.
            output_directory: directory the mockup is written to.
            output_size: optional ``(width, height)`` of the result, see
                ``resize_and_pad``.
            should_trim_alpha: crop transparent borders before resizing.
            background_color: optional colour that makes the result opaque.
            debug: print details about every step when true.

        Raises:
            IOError: if one of the screenshots cannot be read.
        """
        debug_depth = 1 if debug else None

        # read screenshots into memory
        screenshots_3C = []
        for filepath in screenshot_filepaths:
            screenshot = cv.imread(filepath)
            if screenshot is None:
                # imread signals failure with None; fail with the file name
                # instead of an obscure error further down the pipeline
                raise IOError('could not read screenshot: {}'.format(filepath))
            screenshots_3C.append(np.float32(screenshot) / 255.0)

        # embed the screenshots in those device frames
        output_4C = self.embed_screenshots(screenshots_3C, debug=debug)

        # trim alpha from the sides
        if should_trim_alpha:
            output_4C = trim_alpha(output_4C, debug_depth=debug_depth)

        # resize to desired format
        output_4C = resize_and_pad(output_4C, output_size, debug_depth=debug_depth)

        if background_color is not None:
            output_4C = make_opaque(output_4C, background_color, debug_depth=debug_depth)

        # convert back from float to byte
        output_4C = np.uint8(output_4C * 255.0)

        # the mockup is named after the first screenshot and is always
        # written as PNG
        output_filepath = os.path.join(output_directory, self._output_filename(screenshot_filepaths[0]))
        cv.imwrite(output_filepath, output_4C)
        print('mocked', output_filepath, flush=True)

    def mock(self, screenshot_directories, output_directory, output_size=None, should_trim_alpha=False, background_color=None, num_workers=os.cpu_count(), debug=False):
        """Create one mockup per screenshot found in the given directories.

        Every directory provides the screenshots for one device frame. The
        n-th file (in sorted order) of every directory is combined into the
        n-th mockup.

        Args:
            screenshot_directories: directories with screenshots, sorted
                from background to foreground.
            output_directory: directory for the mockups; created if missing.
            output_size: optional ``(width, height)`` of the mockups.
            should_trim_alpha: crop transparent borders of the mockups.
            background_color: optional colour that makes mockups opaque.
            num_workers: maximum number of threads; values up to one (or
                ``None``) disable multithreading.
            debug: print details about every step when true.

        Raises:
            AssertionError: if there are more directories than placeholders
                or the directories hold different numbers of screenshots.
        """
        assert self.num_placeholders >= len(screenshot_directories)

        ############## RETRIEVE LIST OF ALL AVAILABLE SCREENSHOTS ##############

        screenshot_filepaths_per_directory = [
            self._list_screenshots(directory) for directory in screenshot_directories
        ]

        num_screenshots = len(screenshot_filepaths_per_directory[0])
        for filepaths in screenshot_filepaths_per_directory[1:]:
            assert num_screenshots == len(filepaths)

        print('found', num_screenshots, 'screenshots')

        # transpose: one tuple per mockup holding one path per directory
        screenshot_filepaths_per_feature = list(zip(*screenshot_filepaths_per_directory))

        ################# ENSURE OUTPUT DIRECTORY IS AVAILABLE #################

        os.makedirs(output_directory, exist_ok=True)

        arguments = [
            self.create_mockup,
            screenshot_filepaths_per_feature,
            repeat(output_directory),
            repeat(output_size),
            repeat(should_trim_alpha),
            repeat(background_color),
            repeat(debug)
        ]

        if num_workers is not None and num_workers > 1:
            print('using multithreading with up to {} threads.'.format(num_workers))
            with Pool(max_workers=num_workers) as pool:
                for _ in pool.map(*arguments):
                    # consuming the results lets exceptions of the workers fire.
                    pass
        else:
            print('multithreading disabled')
            for _ in map(*arguments):
                pass




def main():
    """Command line entry point: parse the options and render all mockups.

    When called without any command line argument the help text is shown.
    """
    # (flag, keyword arguments) in the order they appear in the help text
    option_table = [
        ('--frame', dict(default='./frame.psd', help='device frame file (default: ./frame.psd)')),
        ('--screenshots', dict(required=True, nargs='+', metavar='DIR', help='directories containing the screenshots to use, sorted from background to foreground')),
        ('--output', dict(default='./mockups/', metavar='DIR', help='directory where the finished mockups are placed (default: ./mockups/)')),
        ('--threads', dict(type=int, default=os.cpu_count(), metavar='COUNT', help='maximum number of threads to use simultaneously (default: cpu count of local machine)')),
        ('--width', dict(type=int, default=None, metavar='PIXEL', help='width of the output image')),
        ('--height', dict(type=int, default=None, metavar='PIXEL', help='height of the output image')),
        ('--trimalpha', dict(action='store_true', help='flag, specifiyng whether the image should be trimmed to opaque pixels')),
        ('--bgcolor', dict(type=float, default=None, nargs=3, metavar='f', help='background color specified as RGB fractions: 0 to 1. Specifiyng this option results in a totally opaque output image.')),
        ('--verbose', dict(action='store_true', help='flag, specifiyng whether debug output is enabled')),
    ]

    parser = argparse.ArgumentParser(description='Combine screenshot and device frame into mockup')
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    # fall back to the help text when no argument was given at all
    cli_arguments = sys.argv[1:] or ['-h']
    args = parser.parse_args(cli_arguments)

    requested_size = (args.width, args.height)
    bird = MockingBird(args.frame)
    bird.mock(
        args.screenshots,
        args.output,
        output_size=requested_size,
        should_trim_alpha=args.trimalpha,
        background_color=args.bgcolor,
        num_workers=args.threads,
        debug=args.verbose
    )

# Run the command line interface only when this file is executed as a script.
if __name__ == '__main__':
    main()
