#!/usr/bin/env python

import argparse
import imagej
import sys
import os
import threading
import time
import matplotlib
from jpype import setupGuiEnvironment, JImplements, JOverride
import traceback

# Command-line interface: a single positional argument naming the Fiji.app directory.
parser = argparse.ArgumentParser(description='Launches Fiji from python using PyImageJ. Requires the Fiji directory as an input. Provides support for Conda Python 3 scripting in the launched Fiji.')
parser.add_argument("path_to_Fiji", help="The file path for Fiji.app directory.")
args = parser.parse_args()

print(args.path_to_Fiji)

# Reject a missing path through the parser itself so the user gets the usual
# usage message and a non-zero exit status rather than a raw traceback.
# (ArgumentTypeError is only turned into a usage error when raised from a
# ``type=`` converter, not when raised from module-level code.)
if not os.path.exists(args.path_to_Fiji):
    parser.error("{0} does not exist".format(args.path_to_Fiji))

# Select matplotlib's non-interactive 'agg' backend before anything plots.
# NOTE(review): presumably to avoid a GUI-toolkit clash with the Java UI - confirm.
matplotlib.use('agg')

# Create the JVM from the given Fiji directory, requesting GUI mode and the legacy layer.
imagej._create_jvm(args.path_to_Fiji, mode=imagej.Mode.GUI, add_legacy=True)

def create_gateway_and_UI():
    """Create the ImageJ gateway, register the Python script runner, show the UI.

    Side effects:
        * ``sys.stdout`` is replaced by a ``ScriptContextWriter`` wrapping the
          previous ``sys.stdout``.
        * A ``PythonScriptRunnerImpl`` instance is added via
          ``ij.object().addObject``.
        * ``ij.ui().showUI()`` is called.

    Returns:
        The gateway object returned by ``imagej._create_gateway()``.
    """
    ij = imagej._create_gateway()

    class ScriptContextWriter():
        """Stand-in for ``sys.stdout`` that dispatches output per thread.

        Text written on a thread registered through ``addScriptContext`` goes
        to that script context's writer; text written on any other thread goes
        to the wrapped default stream.
        """

        def __init__(self, std):
            self._std_default = std
            self._thread_to_context = {}

        def addScriptContext(self, thread, scriptContext):
            """Route output written on ``thread`` to ``scriptContext``."""
            self._thread_to_context[thread] = scriptContext

        def removeScriptContext(self, thread):
            """Stop routing output for ``thread``; unknown threads are ignored."""
            self._thread_to_context.pop(thread, None)

        def _current_context(self):
            """Return the script context registered for the calling thread, or None."""
            return self._thread_to_context.get(threading.current_thread())

        def write(self, s):
            """Write ``s`` to the calling thread's script context, else to the default stream."""
            context = self._current_context()
            if context is not None:
                context.getWriter().write(imagej.sj.to_java(s))
            else:
                self._std_default.write(s)

        def flush(self):
            """Flush the stream that ``write`` would currently target.

            Needed because this object replaces ``sys.stdout``: without it,
            ``print(..., flush=True)`` and ``sys.stdout.flush()`` raise
            ``AttributeError``.
            """
            context = self._current_context()
            if context is not None:
                context.getWriter().flush()
            else:
                self._std_default.flush()

    stdoutContextWriter = ScriptContextWriter(sys.stdout)
    sys.stdout = stdoutContextWriter

    @JImplements('org.scijava.plugins.scripting.python.PythonScriptRunner')
    class PythonScriptRunnerImpl():
        """Python-side implementation of the ``PythonScriptRunner`` Java interface."""

        @JOverride
        def run(self, script, vars, scriptContext):
            """Execute ``script`` and return the names it defined.

            Args:
                script: Script source; converted with ``imagej.sj.to_python``
                    before being passed to ``exec``.
                vars: Mapping of input names to values. Every new name the
                    script defines whose value converts to Java is written
                    back into it.
                scriptContext: Provides ``getWriter()`` for the script's
                    stdout and ``getErrorWriter()`` for its traceback.

            Returns:
                ``ij.py.to_java`` of a dict holding the converted outputs.
            """
            # Globals for the script: the supplied inputs plus the gateway
            # ('ij') and the scyjava module ('sj').
            inputs = {key: vars[key] for key in vars.keys()}
            inputs['ij'] = ij
            inputs['sj'] = imagej.sj

            # Names that exist before execution; exec() adds '__builtins__'
            # itself, so it is listed here to keep it out of the outputs.
            inputKeys = list(inputs) + ['__builtins__']

            thread = threading.current_thread()
            stdoutContextWriter.addScriptContext(thread, scriptContext)
            try:
                exec(imagej.sj.to_python(script), inputs)
            except BaseException:
                # Deliberately catch everything (including SystemExit and
                # KeyboardInterrupt raised by the script) and report it on the
                # script's error writer instead of letting it leave run().
                scriptContext.getErrorWriter().write(imagej.sj.to_java(traceback.format_exc()))
            finally:
                # Always unregister, even if reporting the error itself fails,
                # so later output on this thread is not sent to a stale context.
                stdoutContextWriter.removeScriptContext(thread)

            outputs = {}
            for key, value in inputs.items():
                if key in inputKeys:
                    continue
                try:
                    converted = ij.py.to_java(value)
                except Exception:
                    # Best effort: a value that to_java cannot convert is
                    # simply not reported as an output.
                    continue
                outputs[key] = converted
                # Outputs must be placed back in vars to ensure script
                # parameter outputs are returned.
                vars[key] = converted

            return ij.py.to_java(outputs)

    ij.object().addObject(PythonScriptRunnerImpl())
    ij.ui().showUI()
    return ij

# The launch strategy differs between macOS and every other platform.
macos = sys.platform == 'darwin'

if not macos:
    # Build the gateway and bring up the application on this thread.
    gateway = create_gateway_and_UI()
    # Keeping the process alive is up to us: wait until the UI goes away.
    # TODO: Poll using something better than ui().isVisible().
    while True:
        if not gateway.ui().isVisible():
            break
        time.sleep(1)
else:
    # NB: This call never returns - it blocks the calling (main) thread forever!
    setupGuiEnvironment(create_gateway_and_UI)
