#! /usr/bin/env python3
# ------------------------------------------------------------------------------
# Ascheck computes the fraction of asymmetric pixels in an image of an object
# Copyright (C) 2017  Laurie Hutchence & Stijn Debackere
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# ------------------------------------------------------------------------------
from __future__ import print_function
from ascheck import Image
import cv2
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tkinter import messagebox as mb
import tkinter.filedialog as filedialog
import tkinter.simpledialog as simpledialog


def _segment_object(path, save_dir):
    """Load ``path`` and extract the contour used by both analyses.

    The file is opened with ``Image``; its ``bw`` attribute is thresholded,
    padded by 50 pixels, opened/closed with 4 iterations, and the result of
    ``max_closed_contour`` is filled with ``fill_contour(..., 0)``.

    Parameters
    ----------
    path : str
        File to load.
    save_dir : str
        Output directory handed to ``Image``.

    Returns
    -------
    tuple or None
        ``(img, img_contour, contour)``, or ``None`` when ``Image`` raises
        ``TypeError`` (treated by this script as "not an image file").
    """
    try:
        img = Image(path, save_dir=save_dir)
    except TypeError:
        # if not an image file, let the caller skip it
        return None

    img_thresh = img.threshold_image(img.bw, open_close=False)
    img_final = img.open_and_close_image(
        img.pad_image(img_thresh, size=50), iterations=4
    )
    contour = img.max_closed_contour(img_final)
    img_contour = img.fill_contour(img_final, contour, 0)
    return img, img_contour, contour


def _compute_asymmetry(files, read_dir):
    """Compute the asymmetry index of every image and write a results table.

    Asks whether diagnostic images should be saved, creates ``<read_dir>/bw/``
    and writes ``ascheck_results.txt`` with one ``filename asymmetry flag``
    row per successfully loaded file. The table is written (header only) even
    when no file could be loaded.

    Parameters
    ----------
    files : list of str
        Candidate image paths.
    read_dir : str
        Directory selected by the user.
    """
    diagnostics = mb.askyesno(
        title="Options", message="Save diagnostic images?"
    )
    bw_dir = read_dir + "/bw/"
    # create output directory
    os.makedirs(bw_dir, exist_ok=True)

    # The results file goes to the save_dir reported by the last loaded
    # image (as before); fall back to bw_dir when nothing could be loaded so
    # that an empty directory no longer crashes on an unbound ``img``.
    results_dir = bw_dir
    rows = []
    for f in files:
        segmented = _segment_object(f, bw_dir)
        if segmented is None:
            continue
        img, img_contour, contour = segmented
        results_dir = img.save_dir

        final, contour_centered = img.center_on_contour(
            img_contour, contour, return_contour=True
        )
        A, diff = img.calculate_asymmetry(final.astype(np.uint8))

        # flag image as suspicious if the centered image is less than a
        # sixth of the original along either axis, or if the asymmetry
        # index is 1 or larger
        too_small = (
            final.shape[0] < img.bw.shape[0] / 6.0
            or final.shape[1] < img.bw.shape[1] / 6.0
        )
        flag = "suspicious" if too_small or A >= 1.0 else "OK"

        # get extra diagnostics to check performance
        if diagnostics:
            img_filled = img.fill_contour(
                img.pad_image(img.bw, size=50), contour
            )
            img_outline = img.visualize_contour(contour_centered, size=1000)
            prefix = img.save_dir + img.name
            # the black and white image needs to be inverted for the
            # sum of the asymmetric pixels
            cv2.imwrite(f"{prefix}_bw.png", ~final)
            cv2.imwrite(f"{prefix}_asym.png", diff)
            cv2.imwrite(f"{prefix}_filled{img.ext}", img_filled)
            cv2.imwrite(f"{prefix}_contour.png", img_outline)

        rows.append([f, A, flag])

    # reshape keeps three columns even when ``rows`` is empty
    info = np.array(rows, dtype=object).reshape(-1, 3)
    np.savetxt(
        results_dir + "ascheck_results.txt",
        info.astype(str),
        fmt="%s",
        comments="#",
        header="filename asymmetry_index flag",
    )


def _compute_slices(files, read_dir):
    """Run ``Image.slice_intervals`` on every image.

    Asks for the number of slices (5-100, default 20) and whether diagnostic
    plots should be saved, then processes each loadable file with output
    directory ``<read_dir>/slices/``. Does nothing if the number-of-slices
    dialog is cancelled.

    Parameters
    ----------
    files : list of str
        Candidate image paths.
    read_dir : str
        Directory selected by the user.
    """
    n_slices = simpledialog.askinteger(
        "Pick", "Number of slices", initialvalue=20, minvalue=5, maxvalue=100
    )
    if n_slices is None:
        # dialog was cancelled; np.linspace cannot handle None
        print("No number of slices selected, skipping slices")
        return

    diagnostics = mb.askyesno(
        title="Options", message="Save diagnostic images?"
    )
    intervals = np.linspace(0, 1, n_slices)

    sl_dir = read_dir + "/slices/"
    # create output directory
    os.makedirs(sl_dir, exist_ok=True)

    for f in files:
        segmented = _segment_object(f, sl_dir)
        if segmented is None:
            continue
        img, img_contour, contour = segmented

        img.slice_intervals(img_contour, contour, intervals)

        if diagnostics:
            img_color = img.fill_contour(
                img.pad_image(img.bw, size=50), contour
            )
            # second call with save=False to obtain the coordinates to plot
            c, crds, _, _ = img.slice_intervals(
                img_contour, contour, intervals, save=False
            )

            plt.imshow(img_color)
            plt.plot(c[:, 0], c[:, 1], lw=0, marker="x", markersize=10)
            plt.plot(crds[:, 0], crds[:, 1], lw=0, marker="o", markersize=10)
            plt.savefig(
                img.save_dir + img.name + "_slices.png",
                dpi=200,
                transparent=True,
            )
            # reuse the same figure for the next image
            plt.clf()


def main():
    """Interactively process every file in a user-selected directory.

    Prompts for a directory, then asks whether to compute asymmetry and/or
    slices and runs the selected analyses on all ``*.*`` files in that
    directory, in sorted order. Files that ``Image`` rejects with
    ``TypeError`` are skipped.

    Raises
    ------
    SystemExit
        If no directory is selected.
    """
    # Ask user for input directory
    read_dir = filedialog.askdirectory()
    # exit code if no directory selected
    # don't want to accidentally write to /
    # (a cancelled dialog may return "" or another empty value)
    if not read_dir:
        sys.exit("No directory selected")

    files = sorted(glob.glob(read_dir + "/*.*"))

    asym = mb.askyesno(title="Options", message="Compute asymmetry?")
    sl = mb.askyesno(title="Options", message="Compute slices?")

    if asym:
        _compute_asymmetry(files, read_dir)

    if sl:
        _compute_slices(files, read_dir)

# Only run the interactive workflow when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
