#!python

import os
from os.path import dirname, join, exists, isdir
import sys
import time
import logging
import argparse

import openslide
import numpy as np

import histoprep as hp
from histoprep.helpers._utils import remove_extension, format_seconds
from histoprep._czi_reader import OpenSlideCzi


# Module logger; its own level is set to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Slide file extensions accepted by collect_paths(), without the leading dot.
ALLOWED = 'mrxs tiff svs tif ndpi vms vmu scn svslide bif czi'.split()
# Banner shown as the top-level argparse ``usage`` text in get_arguments().
TITLE = """ python3 HistoPrep {step} {arguments}

█  █  █  ██  ███  ███  ███  ███  ███ ███
█  █  █ █     █  █   █ █  █ █  █ █   █  █
████  █  ██   █  █   █ ███  ███  ██  ███
█  █  █    █  █  █   █ █    █  █ █   █
█  █  █  ██   █   ███  █    █  █ ███ █

             by Jopo666 (2021)
"""


def get_arguments(argv=None):
    """Build the command line interface and parse the arguments.

    Two sub-commands are available: ``cut`` (cut tiles from whole slides)
    and ``dearray`` (dearray TMA slides and optionally cut the spots).

    Args:
        argv (list, optional): Arguments to parse. Defaults to None, in which
            case argparse reads ``sys.argv[1:]``.

    Raises:
        SystemExit: No sub-command was given (via ``parser.error``).
        IOError: ``input_dir`` does not exist.
        NotADirectoryError: ``input_dir`` exists but is not a directory.

    Returns:
        argparse.Namespace: Parsed arguments. ``output_dir`` is created if it
            did not exist yet.
    """
    parser = argparse.ArgumentParser(
        usage=TITLE,
    )
    subparsers = parser.add_subparsers(
        title='Select one of the below',
        dest='step',
        metavar='')
    cut = subparsers.add_parser('cut',
                                help='Cut tiles from large histological slides.',
                                usage='python3 HistoPrep cut input_dir output_dir width {optional arguments}')
    dearray = subparsers.add_parser('dearray',
                                    help='Dearray an tissue microarray (TMA) slide.',
                                    usage='python3 HistoPrep dearray input_dir output_dir {optional arguments}')

    ### CUT ###
    cut.add_argument('input_dir',
                     help='Path to the slide directory.')
    cut.add_argument('output_dir',
                     help="Will be created if doesn't exist.")
    cut.add_argument('width', type=int,
                     help='Tile width.')
    cut.add_argument('--overlap', type=float, default=0.0, metavar='',
                     help='Tile overlap. [Default: %(default)s]')
    cut.add_argument('--max_bg', type=float, default=0.7, metavar='',
                     help='Maximum background percentage for a tile. [Default: %(default)s]')
    cut.add_argument('--downsample', type=int, default=32, metavar='',
                     help='Thumbnail downsample. [Default: %(default)s]')
    cut.add_argument('--threshold', type=float, default=1.1, metavar='',
                     help='Threshold for background detection. [Default: %(default)s]')
    cut.add_argument('--overwrite', action='store_true',
                     help='[Default: %(default)s]')
    cut.add_argument('--image_format', default='jpeg', metavar='',
                     choices=['jpeg', 'png'], help='Image format. [Default: %(default)s]')
    cut.add_argument('--quality', type=int, default=95, metavar='',
                     help='Quality for jpeg compression. [Default: %(default)s]')

    ### DEARRAY ###
    dearray.add_argument('input_dir',
                         help='Path to the slide directory.')
    dearray.add_argument('output_dir',
                         help="Will be created if doesn't exist.")
    dearray.add_argument('--downsample', type=int, default=64, metavar='',
                         help='Thumbnail downsample. [Default: %(default)s]')
    # float, not int: the default (1.1) is fractional, and with type=int any
    # user-supplied value such as "1.2" was rejected.
    dearray.add_argument('--threshold', type=float, default=1.1, metavar='',
                         help="Threshold multiplier for Otsu's method in tissue detection. [Default: %(default)s]")
    dearray.add_argument('--min_area', type=float, default=0.2, metavar='',
                         help='Minimum area for a spot. [Default: median_area*%(default)s]')
    dearray.add_argument('--max_area', type=float, default=None, metavar='',
                         help='Maximum area for a spot. [Default: median_area*%(default)s]')
    dearray.add_argument('--kernel_size', type=int, default=8, metavar='',
                         help='Kernel size for spot detection (give as a single integer). [Default: (%(default)s,%(default)s)]')
    dearray.add_argument('--cut', action='store_true',
                         help='Cut spots after dearraying. [Default: %(default)s]')
    dearray.add_argument('--width', type=int, default=512, metavar='',
                         help='Tile width. [Default: %(default)s]')
    dearray.add_argument('--overlap', type=float, default=0.0, metavar='',
                         help='Tile overlap. [Default: %(default)s]')
    dearray.add_argument('--max_bg', type=float, default=0.7, metavar='',
                         help='Maximum background percentage for a tile. [Default: %(default)s]')
    dearray.add_argument('--overwrite', action='store_true',
                         help='Remove everything before dearraying and cutting. [Default: %(default)s]')
    dearray.add_argument('--image_format', default='jpeg', metavar='', choices=['jpeg', 'png'],
                         help='Image format (jpeg or png). [Default: %(default)s]')
    dearray.add_argument('--quality', type=int, default=95, metavar='',
                         help='Quality for jpeg compression. [Default: %(default)s]')

    args = parser.parse_args(argv)

    # Sub-commands are optional for argparse, so without one ``args`` has no
    # input_dir/output_dir; report a usage error instead of an AttributeError.
    if args.step is None:
        parser.error('Please select a step: cut or dearray.')

    # Check paths.
    if not os.path.exists(args.input_dir):
        raise IOError(f'Path {args.input_dir} not found.')
    if not os.path.isdir(args.input_dir):
        # The input is scanned with os.scandir, which needs a directory.
        raise NotADirectoryError(f'Path {args.input_dir} is not a directory.')

    os.makedirs(args.output_dir, exist_ok=True)

    return args


def get_etc(times: list, tic: float, num_left: int):
    """Record the latest duration, print an ETC estimate and return ``times``.

    The time elapsed since ``tic`` is appended to ``times`` in place. The
    printed estimate is the mean of all recorded durations multiplied by
    ``num_left``.

    Args:
        times (list): Durations (in seconds) recorded so far; extended in place.
        tic (float): ``time.time()`` timestamp taken at the start of the
            iteration being measured.
        num_left (int): Number of iterations the estimate should cover.

    Returns:
        list: The same ``times`` list, now including the new duration.
    """
    elapsed = time.time() - tic
    times.append(elapsed)
    remaining_seconds = np.mean(times) * num_left
    print(f'ETC: {format_seconds(remaining_seconds)}')
    return times


def collect_paths(args):
    """Collect the slide files in ``args.input_dir`` that still need processing.

    A directory entry counts as a slide when it is a file whose extension
    (compared case-insensitively, without the dot) is listed in ``ALLOWED``.
    Unless ``args.overwrite`` is set, slides that already have
    ``<output_dir>/<name>/metadata.csv`` are dropped.

    Calls ``sys.exit()`` when no slides are found, or when every slide has
    already been processed.

    Args:
        args (argparse.Namespace): Needs ``input_dir``, ``output_dir`` and
            ``overwrite``.

    Returns:
        list: ``os.DirEntry`` objects of the slides to process, sorted by name.
    """
    # The context manager releases the directory handle once scanning is done.
    with os.scandir(args.input_dir) as entries:
        paths = [
            entry for entry in entries
            if entry.is_file()
            and os.path.splitext(entry.name)[1][1:].lower() in ALLOWED
        ]
    # os.scandir yields entries in arbitrary order; sort for reproducible runs.
    paths.sort(key=lambda entry: entry.name)
    if not paths:
        logger.warning(
            'No slides found! Please check that input_dir '
            'you defined is correct!'
        )
        sys.exit()
    if not args.overwrite:
        # See which slides have been processed before.
        not_processed = [
            entry for entry in paths
            if not exists(join(args.output_dir, remove_extension(entry.name),
                               'metadata.csv'))
        ]
        if not not_processed:
            print('All slides have been cut!')
            sys.exit()
        num_cut = len(paths) - len(not_processed)
        if num_cut != 0:
            print(f'{num_cut} slide(s) had been cut before.')
        paths = not_processed
    return paths


def check_file(file):
    """Check that ``file`` can be opened as a slide.

    Names ending in ``czi`` (case-insensitively) are opened with
    ``OpenSlideCzi``; everything else is opened with ``openslide.OpenSlide``.
    The reader is closed again right away if it has a ``close`` method.

    Args:
        file (os.DirEntry): Slide entry; ``name`` and ``path`` are used.

    Returns:
        bool: True if the slide opened; otherwise a warning is logged and
            False is returned.
    """
    try:
        if file.name.lower().endswith('czi'):
            reader = OpenSlideCzi(file.path)
        else:
            reader = openslide.OpenSlide(file.path)
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt
        # and SystemExit still stop the program instead of being reported as
        # broken slides.
        logger.warning(f'Slide broken! Skipping {file.name}')
        return False
    # Only opened to validate the file; release the handle.
    close = getattr(reader, 'close', None)
    if close is not None:
        close()
    return True


def check_downsamples(path, downsample):
    """Find an available pyramid downsample close to ``downsample``.

    The slide's ``level_downsamples`` are rounded to integers. They are then
    searched for ``downsample`` itself, then ``downsample*2``, then
    ``downsample/2``. A warning is logged when a substitute is used.

    Args:
        path (str): Path to the slide.
        downsample (int): Requested downsample.

    Returns:
        int or None: The matching downsample. None when no close level exists,
            or when OpenSlide cannot open the slide (e.g. a format read by a
            different reader).
    """
    try:
        reader = openslide.OpenSlide(path)
    except openslide.OpenSlideError:
        # Return None instead of crashing the whole batch; the caller then
        # falls back to the requested downsample.
        return None
    try:
        # Level downsamples are floats that can be slightly off (e.g. 31.99),
        # so round them instead of truncating with int().
        available = {int(round(x)) for x in reader.level_downsamples}
    finally:
        reader.close()
    for candidate in (downsample, int(downsample * 2), int(downsample / 2)):
        if candidate in available:
            if candidate != downsample:
                logger.warning(
                    f'Downsample {downsample} not available, '
                    f'using {candidate}.'
                )
            return candidate
    return None


def cut_tiles(args):
    """Cut tiles from every unprocessed slide in ``args.input_dir``.

    For each slide, the following steps are run:

    1. Check that the slide opens.
    2. Print progress and an ETC estimate.
    3. Pick a thumbnail downsample.
    4. Build an ``hp.Cutter``.
    5. Save its tiles to ``args.output_dir``.

    Broken slides, and slides whose Cutter cannot be created, are skipped
    with a warning. A KeyboardInterrupt while preparing the Cutter exits the
    program.

    Args:
        args (argparse.Namespace): Parsed ``cut`` arguments.
    """
    # Collect all slide paths
    slides = collect_paths(args)
    # Loop through each slide and cut.
    total = str(len(slides))
    print(f'HistoPrep will process {total} slides.')
    # Durations of processed slides, used for the ETC estimate.
    times = []
    # Start time of the previous readable slide; None before the first one.
    tic = None
    for i, f in enumerate(slides):
        # 1-based counter so the last slide reads [N/N].
        print(f'[{str(i + 1).rjust(len(total))}/{total}] - {f.name}', end=' - ')
        if not check_file(f):
            continue
        # Calculate ETC.
        if tic is None:
            print('ETC: ...')
        else:
            # The current slide and every later one are still to be cut.
            times = get_etc(times=times, tic=tic, num_left=len(slides) - i)
        # Start time.
        tic = time.time()
        # Check downsample.
        downsample = check_downsamples(f.path, args.downsample)
        if downsample is None:
            logger.warning(
                f'No downsample close to {args.downsample} available, '
                f'trying to generate a thumbnail image.'
            )
            downsample = args.downsample
        # Prepare Cutter.
        try:
            cutter = hp.Cutter(
                slide_path=f.path,
                width=args.width,
                overlap=args.overlap,
                threshold=args.threshold,
                downsample=downsample,
                max_background=args.max_bg,
                create_thumbnail=True,
            )
        except KeyboardInterrupt:
            logger.warning('KeyboardInterrupt detected. Shutting down.')
            sys.exit()
        except Exception as e:
            logger.warning(
                f'Something went wrong with error: "{e}"'
                f'\nSkipping slide {f.name}.'
            )
            continue

        # Cut cut cut away!
        cutter.save(
            output_dir=args.output_dir,
            overwrite=args.overwrite,
            image_format=args.image_format,
            quality=args.quality,
        )
    print(f'All {total} slides processed.')


def dearray(args):
    """Dearray every unprocessed TMA slide in ``args.input_dir``.

    For each array, the following steps are run:

    1. Check that the slide opens.
    2. Print progress and an ETC estimate.
    3. Build an ``hp.Dearrayer``.
    4. Save the spots to ``args.output_dir``.
    5. With ``--cut``, also save tiles cut from the spots.

    Broken slides, and slides whose Dearrayer cannot be created, are skipped
    with a warning. A KeyboardInterrupt while preparing the Dearrayer exits
    the program.

    Args:
        args (argparse.Namespace): Parsed ``dearray`` arguments.
    """
    # Collect all slide paths.
    tma_arrays = collect_paths(args)
    # Loop through each array and cut.
    total = str(len(tma_arrays))
    print(f'HistoPrep will process {total} TMA arrays.')
    # Durations of processed arrays, used for the ETC estimate.
    times = []
    # Start time of the previous readable array. Testing ``tic is None``
    # (instead of ``i == 0``) avoids an UnboundLocalError when the first
    # array is broken and skipped.
    tic = None
    for i, f in enumerate(tma_arrays):
        # 1-based counter so the last array reads [N/N].
        print(f'[{str(i + 1).rjust(len(total))}/{total}] - {f.name}', end=' - ')
        if not check_file(f):
            continue
        # Calculate ETC
        if tic is None:
            print('ETC: ...')
        else:
            # The current array and every later one are still to be processed.
            times = get_etc(times=times, tic=tic, num_left=len(tma_arrays) - i)
        # Start time.
        tic = time.time()
        # Prepare Dearrayer; same error handling as cut_tiles so one bad
        # array does not abort the whole batch.
        try:
            dearrayer = hp.Dearrayer(
                slide_path=f.path,
                threshold=args.threshold,
                downsample=args.downsample,
                min_area_multiplier=args.min_area,
                max_area_multiplier=args.max_area,
                kernel_size=(args.kernel_size, args.kernel_size),
                create_thumbnail=True,
            )
        except KeyboardInterrupt:
            logger.warning('KeyboardInterrupt detected. Shutting down.')
            sys.exit()
        except Exception as e:
            logger.warning(
                f'Something went wrong with error: "{e}"'
                f'\nSkipping slide {f.name}.'
            )
            continue
        # Dearray away!
        dearrayer.save_spots(
            output_dir=args.output_dir,
            overwrite=args.overwrite,
            image_format=args.image_format,
            quality=args.quality,
        )
        if args.cut:
            dearrayer.save_tiles(
                width=args.width,
                overlap=args.overlap,
                max_background=args.max_bg,
                overwrite=args.overwrite,
                image_format=args.image_format,
                quality=args.quality,
            )
    print(f'All {total} TMA arrays processed.')


if __name__ == '__main__':
    args = get_arguments()

    if args.step == 'cut':
        cut_tiles(args)

    elif args.step == 'dearray':
        dearray(args)

    else:
        # NotImplemented is a constant, not an exception class. Calling it
        # raised a TypeError instead of this message, so use
        # NotImplementedError.
        raise NotImplementedError(
            "I don't know how you did that, but that's not allowed."
        )
