#!/usr/bin/env python

import argparse
import functools
import os
import sys
import time
from typing import Tuple

import histoprep
from histoprep.helpers import multiprocess_loop
from histoprep.helpers._files import remove_extension
from histoprep.helpers._time import format_seconds
from histoprep.helpers._verbose import verbose_fn

# Usage text handed to ``argparse.ArgumentParser(usage=...)``: the first line
# shows the positional arguments, the rest is an ASCII-art banner.
TITLE = """HistoPrep input_dir output_dir width {optional arguments}

██╗  ██╗██╗███████╗████████╗ ██████╗ ██████╗ ██████╗ ███████╗██████╗
██║  ██║██║██╔════╝╚══██╔══╝██╔═══██╗██╔══██╗██╔══██╗██╔════╝██╔══██╗
███████║██║███████╗   ██║   ██║   ██║██████╔╝██████╔╝█████╗  ██████╔╝
██╔══██║██║╚════██║   ██║   ██║   ██║██╔═══╝ ██╔══██╗██╔══╝  ██╔═══╝
██║  ██║██║███████║   ██║   ╚██████╔╝██║     ██║  ██║███████╗██║
╚═╝  ╚═╝╚═╝╚══════╝   ╚═╝    ╚═════╝ ╚═╝     ╚═╝  ╚═╝╚══════╝╚═╝
                        by Jopo666 (2022)
"""


def get_arguments(argv=None):
    """Parse and validate the command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``output_dir`` is guaranteed to
        exist as a directory when this function returns.

    Raises:
        FileNotFoundError: If ``input_dir`` does not exist.
        NotADirectoryError: If ``input_dir`` exists but is not a directory.
        FileExistsError: If ``output_dir`` exists but is not a directory.
    """
    parser = argparse.ArgumentParser(
        usage=TITLE,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("input_dir", help="Directory with slides.")
    parser.add_argument("output_dir", help="For processed slides.")
    parser.add_argument("width", type=int, help="Width of the tile.")
    parser.add_argument(
        "--overlap",
        type=float,
        default=0.1,
        metavar="",
        help="Overlap between neighbouring tiles.",
    )
    parser.add_argument(
        "--max_background",
        type=float,
        default=0.75,
        metavar="",
        help="Maximum allowed background per tile.",
    )
    parser.add_argument(
        "--height",
        type=int,
        metavar="",
        default=None,
        help="Height of the tile. If None, set to width.",
    )
    parser.add_argument(
        "--level",
        type=int,
        metavar="",
        default=0,
        help="Slide level for reading tile regions.",
    )
    parser.add_argument(
        "--ext",
        nargs="+",
        default=None,
        metavar="",
        help="File extensions to load. If not set, uses all readable extensions.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        # os.cpu_count() may return None when the count cannot be determined.
        default=os.cpu_count() or 1,
        metavar="",
        help="Number of image saving worker processes.",
    )
    parser.add_argument(
        "--threshold",
        type=int,
        default=None,
        metavar="",
        help="Threshold for tissue detection. If None, set with Otsu's method.",
    )
    parser.add_argument(
        "--threshold_multiplier",
        type=float,
        default=1.05,
        metavar="",
        help=(
            "Multiply Otsu's threshold with this value. Ignored if threshold is set."
        ),
    )
    parser.add_argument(
        "--max_dimension",
        type=int,
        default=2**14,
        metavar="",
        help="Maximum dimension for the thumbnail.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Removes everything in output folder before saving images.",
    )
    parser.add_argument(
        "--image_format",
        default="jpeg",
        metavar="",
        choices=["jpeg", "png"],
        help="Image format.",
    )
    parser.add_argument(
        "--quality",
        type=int,
        default=95,
        metavar="",
        help="Quality for jpeg compression.",
    )
    parser.add_argument(
        "--depth",
        type=int,
        default=0,
        metavar="",
        help="Depth for recursively finding slide images from the input_dir.",
    )

    args = parser.parse_args(argv)

    # Check paths.
    if not os.path.exists(args.input_dir):
        raise FileNotFoundError(f"Path {args.input_dir} not found.")
    # Fail here rather than later inside os.scandir when a file was passed.
    if not os.path.isdir(args.input_dir):
        raise NotADirectoryError(f"Path {args.input_dir} is not a directory.")

    if os.path.exists(args.output_dir) and not os.path.isdir(args.output_dir):
        raise FileExistsError("Output directory exists but it is not a directory.")
    os.makedirs(args.output_dir, exist_ok=True)
    return args


def find_files(parent_dir: str, extensions: Tuple[str, ...], depth: int = 0):
    """Find all files under ``parent_dir`` matching the passed extensions.

    Args:
        parent_dir: Directory where the search starts.
        extensions: File extensions to match, with or without the leading
            dot. A single string is treated as one extension. Matching is
            case-sensitive and done against the end of the file name.
        depth: How many directory levels below ``parent_dir`` are searched.
            ``0`` only looks at ``parent_dir`` itself, ``-1`` removes the
            limit.

    Returns:
        List of matching file paths, in directory listing order.
    """
    if isinstance(extensions, str):
        # Iterating a bare string would yield one "extension" per character.
        extensions = (extensions,)
    extensions = tuple(x if x.startswith(".") else "." + x for x in extensions)
    return _list_files(parent_dir, extensions, 0, depth)


def _list_files(dir_path: str, extensions: tuple, current_depth: int, max_depth: int):
    """Helper function to list files recursively.

    Args:
        dir_path: Directory to list.
        extensions: Tuple of dot-prefixed extensions passed to ``str.endswith``.
        current_depth: Depth of ``dir_path`` relative to the search root.
        max_depth: Deepest level to descend to, ``-1`` for no limit.

    Returns:
        List of matching file paths found in ``dir_path`` and, depth
        permitting, in its subdirectories.
    """
    # Read the listing up front so the directory handle is closed before
    # recursing instead of staying open for the whole subtree walk.
    with os.scandir(dir_path) as scanner:
        entries = list(scanner)
    descend = max_depth == -1 or current_depth + 1 <= max_depth
    file_paths = []
    for entry in entries:
        if entry.is_file() and entry.name.endswith(extensions):
            file_paths.append(entry.path)
        if descend and entry.is_dir():
            file_paths += _list_files(
                entry.path, extensions, current_depth + 1, max_depth
            )
    return file_paths


def check_existing_paths(slide_paths: list, output_dir: str, overwrite: bool) -> list:
    """Report slides that already have output and optionally drop them.

    A slide counts as existing when ``output_dir`` contains an entry named
    after the slide file with its extension removed.

    Args:
        slide_paths: Paths of the slide files to check.
        output_dir: Directory where processed slides are saved.
        overwrite: If True, existing slides are kept in the result (and
            reported as being overwritten) instead of being skipped.

    Returns:
        ``slide_paths`` unchanged when ``overwrite`` is set, otherwise a new
        list with only the slides that do not have output yet.
    """
    new_paths = [
        path
        for path in slide_paths
        if not os.path.exists(
            os.path.join(output_dir, remove_extension(os.path.basename(path)))
        )
    ]
    exists = len(slide_paths) - len(new_paths)
    if exists > 0:
        print(
            "{} {} slide paths (from {}) as they exists.".format(
                "Overwriting" if overwrite else "Skipping",
                exists,
                len(slide_paths),
            )
        )
    return slide_paths if overwrite else new_paths


def main(args):
    """Find the slides in ``args.input_dir`` and cut them with worker processes.

    Args:
        args: Parsed command line arguments as returned by ``get_arguments``.
    """
    # Collect paths.
    extensions = histoprep.READABLE_FORMATS if args.ext is None else args.ext
    found_paths = find_files(
        args.input_dir,
        extensions=tuple(extensions),
        depth=args.depth,
    )
    if not found_paths:
        print("No readable slides found in {}.".format(args.input_dir))
        return
    print("Found {} slides in {}".format(len(found_paths), args.input_dir))
    # Drop already processed slides (unless overwriting). Done after the
    # "found" report so a fully processed input is not reported as empty.
    slide_paths = check_existing_paths(found_paths, args.output_dir, args.overwrite)
    if not slide_paths:
        print("All slides have already been processed, nothing to do.")
        return
    # Define worker fn.
    # ``num_workers`` can be None if the CPU count could not be determined.
    num_workers = max(1, min(args.num_workers or 1, len(slide_paths)))
    worker_fn = functools.partial(cut_slide, args=args, total_slides=len(slide_paths))
    # Process slides.
    print(
        "Processing {} slides with {} worker processes.".format(
            len(slide_paths), num_workers
        )
    )
    # The results are not needed, the loop only drives the iteration.
    for __ in multiprocess_loop(worker_fn, enumerate(slide_paths), num_workers):
        pass


def cut_slide(info: Tuple[int, str], args: argparse.Namespace, total_slides: int):
    """Read a single slide, select tile coordinates and save the tiles.

    Errors while reading the slide or saving tiles are reported and the
    slide is skipped; a keyboard interrupt terminates the process with exit
    code 9.

    Args:
        info: ``(index, path)`` pair for the slide. The zero-based index is
            only used for the ``index/total`` prefix of the messages.
        args: Parsed command line arguments.
        total_slides: Total number of slides being processed.
    """
    tic = time.perf_counter()
    # Unpack.
    idx, path = info
    # Define verbose function.
    verbose = functools.partial(
        verbose_fn, desc="{}/{}".format(idx + 1, total_slides), color=False
    )
    verbose("Reading slide from {}".format(path))
    try:
        # Read slide.
        reader = histoprep.SlideReader(
            path,
            verbose=False,
            max_dimension=args.max_dimension,
            threshold=args.threshold,
            threshold_multiplier=args.threshold_multiplier,
        )
    except KeyboardInterrupt:
        print("Detected keyboard interrupt.")
        sys.exit(9)
    except Exception as e:
        verbose("Could not process slide {} due to error:\n    {}".format(path, e))
        return
    try:
        # Get coordinates. The overlap and background limit come from the
        # command line so that ``--overlap`` and ``--max_background`` take
        # effect instead of the reader defaults.
        coordinates = reader.get_tile_coordinates(
            width=args.width,
            overlap=args.overlap,
            height=args.height,
            level=args.level,
            max_background=args.max_background,
        )
        if len(coordinates) == 0:
            verbose("Could not find tiles from {}.".format(reader.slide_name))
            return
        verbose("Saving {} tiles from {}.".format(len(coordinates), reader.slide_name))
        reader.save_tiles(
            output_dir=args.output_dir,
            coordinates=coordinates,
            level=args.level,
            overwrite=args.overwrite,
            image_format=args.image_format,
            quality=args.quality,
            # Slides are already distributed over worker processes.
            num_workers=1,
            display_progress=False,
        )
    except KeyboardInterrupt:
        print("Detected keyboard interrupt.")
        sys.exit(9)
    except Exception as e:
        verbose("Could not save tiles from {} due to error: {}.".format(path, e))
        return
    verbose(
        "Slide {} processed in {}.".format(
            reader.slide_name, format_seconds(time.perf_counter() - tic)
        )
    )


if __name__ == "__main__":
    # Script entry point: parse the command line, then process the slides.
    main(get_arguments())
