#!/usr/bin/env python3

import argparse

# Command-line interface of the script.  ArgumentDefaultsHelpFormatter appends
# each option's default value to its help text.  Every help string gets its
# own name so that the builtin ``help`` is never rebound.
desc = ('Convert a nii image to pictures of slices. If the image and the label '
        'image are both given, their overlay is converted')
formatter = argparse.ArgumentDefaultsHelpFormatter
parser = argparse.ArgumentParser(description=desc, formatter_class=formatter)

# Input volumes.
parser.add_argument('-i', '--image', required=True,
                    help='The image to convert.')
parser.add_argument('-l', '--label-image',
                    help='The corresponding label image.')

colormap_help = ('The colormap to load. It should be .txt or .npy. See '
                 'segviz.load_colors for more details of the file format. It '
                 'only has an effect when using the -l option to show the '
                 'overlay.')
parser.add_argument('-c', '--colormap', default=None, required=False,
                    help=colormap_help)

parser.add_argument('-a', '--alpha', default=1.0, type=float,
                    help='The transparency of the label image.')
parser.add_argument('-v', '--view', choices=['axial', 'coronal', 'sagittal'],
                    default='axial', help='The view to convert into.')

# Intensity scaling.
vmin_help = ('Min cutoff percentage of image intensities. Any intensities '
             'less than or equal to this value is set to 0. If None, use the '
             'minimum of the image.')
parser.add_argument('-m', '--vmin', type=float, default=None, help=vmin_help)

vmax_help = ('Max cutoff percentage of image intensities. Any intensities '
             'greater than or equal to this value is set to 255. If None, use '
             'the maximum of the image.')
parser.add_argument('-M', '--vmax', type=float, default=None, help=vmax_help)

reindex_help = ('By default, the value of a label is directly the index of a '
                'color; in case the colors is only stored in the order of the '
                'ascent of the label values (for example, labels are 2, 5, '
                '10, but there are only three colors, we need to convert 2, '
                '5, 10 to 0, 1, 2), use this option to convert the colors '
                'array so that (2, 5, 10) rows of the new array has the (0, '
                '1, 2) rows of the original colors.')
parser.add_argument('-r', '--reindex-colors', action='store_true',
                    help=reindex_help)

parser.add_argument('-o', '--output', required=True, help='The output folder.')
parser.add_argument('-e', '--show-only-edge', action='store_true',
                    help='Only show the edges of each label region.')
parser.add_argument('-d', '--edge-width', type=int, default=1,
                    help='Edge width of the label map if -e is specified.')
parser.add_argument('-q', '--quantile', default=False, action='store_true',
                    help='Use quantiles to automatically scale the image.')

affine_help = ('Use affine transformation into world coordination. Otherwise, '
               'the orientation is roughly inferenced and requires the affine '
               'matrix read from the nifti is close to an permutation matrix '
               '(voxel size are almost isotropic and not very oblique.)')
parser.add_argument('-A', '--use-affine', action='store_true',
                    help=affine_help)

# Slice selection.
show_range_help = ('Print the range of slices containing the label image. '
                   'Note that the stop is the index of the last slice + 1.')
parser.add_argument('-s', '--show-slice-range', action='store_true',
                    help=show_range_help)

label_slices_help = 'Only show the slices with non-background labels.'
parser.add_argument('-S', '--only-show-label-slices',
                    action='store_true', help=label_slices_help)
parser.add_argument('-t', '--slice-step', default=1, type=int,
                    help='Show every this number of slices.')

extra_slices_help = ('The number of extra non-labeled slices to show when -S '
                     'is used.')
parser.add_argument('-E', '--num-extra-slices', type=int, default=0,
                    help=extra_slices_help)

# Concatenated output.
concat_help = 'Concatenate slices into a single .png file as a row/column.'
parser.add_argument('-C', '--concat', action='store_true', help=concat_help)
parser.add_argument('-p', '--padding-size', default=0, type=int,
                    help='The size of padding when use -C.')
parser.add_argument('-P', '--padding-color',
                    type=int, nargs=4, default=[0, 0, 0, 255],
                    help='The padding color (0..255 RGBA) when use -C.')
parser.add_argument('-R', '--concat-as-row', action='store_true',
                    help='Show the slices as a row when use -C.')
parser.add_argument('-n', '--slice-range', type=int, default=None, nargs=2,
                    help='The slice range to show.')

# Parse the command line here, ahead of the third-party imports that follow.
# NOTE(review): presumably done so that --help and usage errors exit without
# paying for those imports -- confirm before moving the imports to the top.
args = parser.parse_args()


import sys
import nibabel as nib
import numpy as np
from PIL import Image
from pathlib import Path

from improc3d import transform_to_axial
from improc3d import transform_to_coronal
from improc3d import transform_to_sagittal
from segviz.render import ImageRenderer, ImagePairRenderer, ImageEdgeRenderer
from segviz.utils import get_default_colormap, load_colors, append_alpha_column
from segviz.utils import reindex_colors, concat_pils


# A colormap is only needed for the overlay, so it is prepared only when a
# label image is given: either the file named by -c/--colormap or the segviz
# default colormap.
if args.label_image is not None:
    if args.colormap is None:
        colors = get_default_colormap()
        # The warning names the option as it is spelled on the command line.
        print('--colormap is not specified. Use default colormap instead.',
              file=sys.stderr)
    else:
        colors = load_colors(args.colormap)
    # Colors with three columns (presumably RGB -- TODO confirm) get an alpha
    # column appended.
    if colors.shape[1] == 3:
        colors = append_alpha_column(colors)

# Pick the improc3d reorientation routine matching the requested view.  The
# argument parser restricts --view to exactly these three names.
view_transforms = {
    'axial': transform_to_axial,
    'coronal': transform_to_coronal,
    'sagittal': transform_to_sagittal,
}
transform = view_transforms[args.view]
# The transform is asked for its coarse mode unless -A/--use-affine is given.
coarse = not args.use_affine

# Read the image volume as float32 and reorient it into the chosen view.  The
# affine is kept because the label volume is transformed with it as well.
nifti = nib.load(args.image)
affine = nifti.affine
image = transform(nifti.get_fdata(dtype=np.float32), affine,
                  order=1, coarse=coarse)

# Build the renderer: an image/label overlay when a label image is given,
# otherwise the image alone.
if args.label_image is not None:
    # Labels are read as floats, rounded to integers and reoriented with the
    # affine of the intensity image, using order=0 instead of order=1.
    raw_label = nib.load(args.label_image).get_fdata(dtype=np.float32)
    label = transform(np.round(raw_label).astype(int), affine,
                      order=0, coarse=coarse)

    # A label value is used directly as a row index into the colors, so the
    # largest label must be a valid row; otherwise fall back to reindexing.
    label_values = np.unique(label)
    largest_label = np.max(label_values)
    if largest_label >= colors.shape[0] and not args.reindex_colors:
        print('The number of colors is less than the number of labels.',
              'Set --reindex-colors to True.', file=sys.stderr)
        args.reindex_colors = True
    if args.reindex_colors:
        colors = reindex_colors(colors, label_values)

    if not args.show_only_edge:
        renderer = ImagePairRenderer(image, label, colors, alpha=args.alpha)
    else:
        renderer = ImageEdgeRenderer(image, label, colors, alpha=args.alpha,
                                     edge_width=args.edge_width)
else:
    renderer = ImageRenderer(image)

# Intensity scaling: automatic when -q is given, else the --vmin/--vmax cutoffs.
if not args.quantile:
    renderer.rescale_intensity(args.vmin, args.vmax)
else:
    renderer.automatic_rescale()

# -s: report the range of labeled slices, widened by -E on both sides and
# clipped to the volume.  The printed stop is exclusive.
if args.show_slice_range:
    if args.label_image is None:
        print('Label image is not specified. Cannot print label range.',
              file=sys.stderr)
    else:
        start, stop = renderer.get_label_range()
        start = max(start - args.num_extra_slices, 0)
        stop = min(stop + args.num_extra_slices, len(renderer))
        print(start, stop)

num_slices = len(renderer)
# Width of the zero-padded slice index used in the output file names.
num_digits = len(str(num_slices))

# range() raises ValueError for a zero step and silently yields nothing for a
# negative one, so reject both with a regular usage error instead.
if args.slice_step < 1:
    parser.error('-t/--slice-step must be a positive integer.')

# Default selection: the whole volume, or the -n range clipped to the volume.
if args.slice_range is not None:
    start = max(args.slice_range[0], 0)
    stop = min(args.slice_range[1], num_slices)
    slice_range = range(start, stop, args.slice_step)
else:
    slice_range = range(0, num_slices, args.slice_step)

# -S: replace the selection by the labeled slices (plus -E extra slices).
# Without a label image the selection above is kept, so the "Override -n"
# notice is only printed when -n really is overridden.
if args.only_show_label_slices:
    if args.label_image is None:
        print('Label image is not specified. Cannot show label slices.',
              file=sys.stderr)
    else:
        if args.slice_range is not None:
            print('Show the labeled slices. Override -n.', file=sys.stderr)
        start, stop = renderer.get_label_range()
        start = max(start - args.num_extra_slices, 0)
        stop = min(stop + args.num_extra_slices, num_slices)
        slice_range = range(start, stop, args.slice_step)

# Write the selected slices: one .png per slice inside the output folder, or
# with -C a single .png named after the output path.
args.output = Path(args.output)
if not args.concat:
    args.output.mkdir(parents=True, exist_ok=True)
    for s_ind in slice_range:
        # File names are the slice index zero-padded to num_digits digits.
        slice_name = str(s_ind).zfill(num_digits) + '.png'
        renderer[s_ind].save(args.output / slice_name)

else:
    # One picture per inner list; zip(*...) turns that single column of
    # pictures into a single row when -R is given.
    column = [[renderer[s_ind]] for s_ind in slice_range]
    grid = list(zip(*column)) if args.concat_as_row else column
    # NOTE(review): --padding-size is parsed but never passed on here; only
    # the padding color reaches concat_pils -- confirm whether that is
    # intended.
    merged = concat_pils(grid, bg_color=args.padding_color)
    target = args.output.with_suffix('.png')
    args.output.parent.mkdir(parents=True, exist_ok=True)
    merged.save(target)
