#!/usr/bin/env python

from typing import Union, Set


def ids_from_file(filename: Union[None, str]) -> Set[int]:
    """ Collects the integer IDs stored in a text file.
    The file content is split on whitespace and every resulting token is converted with int().
    :param filename: Path to file containing IDs; None or an empty string yields an empty set
    :return: set of integer IDs
    """
    if filename:
        with open(filename, 'r') as id_file:
            tokens = id_file.read().split()
        # Duplicated IDs collapse naturally because the result is a set.
        return {int(token) for token in tokens}
    return set()


def main():
    """ Command-line entry point: exports the SOF_hip dataset as image files.

    Parses the command-line options, loads the 'train' split of the (optionally configured) SOF_hip
    dataset through TFDS and hands it, together with the selected options, to
    sof_utils.export.export_images.

    Exits with status 1 when incompatible options are combined. argparse itself exits with status 2
    on malformed arguments and with status 0 after printing the help or version text.
    """
    import argparse
    import sys

    import tensorflow_datasets as tfds
    import sof_utils
    from sof_utils.export import export_images

    parser = argparse.ArgumentParser(description='Exports the SOF_hip dataset as png images')
    parser.add_argument('target_path', type=str,
                        help='Path to place the images in.')
    # The 'version' action prints and exits during parsing, so -V works without the otherwise
    # required positional target_path (a plain store_true flag would never be reached).
    parser.add_argument('-V', '--version', action='version', version=str(sof_utils.__version__),
                        help="Print the version string")
    parser.add_argument('--data_dir', '-d', type=str, default=None,
                        help='Path to the TFDS dataset.')
    parser.add_argument('--configuration', '-c', type=str, default=None,
                        help='Dataset configuration.')
    parser.add_argument('--format', '-t', type=str, choices=['png', 'jpeg'], default='png',
                        help="Output image format, default to png")
    parser.add_argument('--width', '-w', type=int, default=None,
                        help='Target width of the exported images')
    parser.add_argument('--height', '-g', type=int, default=None,
                        help='Target height of the exported images')
    parser.add_argument('--split_lr', '-s', action='store_true',
                        help='Split radiographs vertically into a left and a right half. The exported files will get '
                             'a "L" or "R" postfix, depending on the side of the hip that is shown. Note that the '
                             'left hip is usually shown on the right hand side of the radiograph. If the (possibly '
                             'downscaled) image does not have an even width, it is padded with one additional column '
                             'to the right.')
    parser.add_argument('--flip_lr', '-p', action='store_true',
                        help='Flips the images of the left hip so that they look like a right hip. Can only be used '
                             'in combination with --split_lr')
    parser.add_argument('--visits', '-v', type=str, default=None,
                        help='Comma separated list of visits to export. If this option is omitted, all visits are '
                             'exported.')
    parser.add_argument('--include', type=str, default=None,
                        help='Only include the IDs listed in given file. The IDs should be listed one ID per line.')
    parser.add_argument('--exclude', type=str, default=None,
                        help='Exclude the IDs listed in given file. The IDs should be listed one ID per line.')
    parser.add_argument('--max-group-size', type=int, default=None,
                        help='Splits the exported images into groups. The groups will contain at most the specific '
                             'number of files. Cannot be used together with --num-groups')
    parser.add_argument('--num-groups', type=int, default=None,
                        help='Splits the exported images into the specified number of groups. Cannot be used together '
                             'with --max-group-size.')
    parser.add_argument('--randomized-groups', action='store_true',
                        help='If grouping is used (either with --max-group-size or with --num-groups), the assignment '
                             'of files to groups will be randomized.')

    args = parser.parse_args()

    if args.flip_lr and not args.split_lr:
        print("Error: --flip_lr can only be used in combination with --split_lr.\n", file=sys.stderr)
        # Keep the usage text on the same stream as the error message.
        parser.print_usage(sys.stderr)
        sys.exit(1)

    if args.num_groups and args.max_group_size:
        print("Error: specifying both --max-group-size and --num-groups is not allowed!", file=sys.stderr)
        sys.exit(1)

    # Turn the comma separated --visits value into a list of integers; an empty list is passed
    # when the option is omitted. Empty items (e.g. from a trailing comma) are skipped.
    visits = []
    if args.visits:
        try:
            visits = [int(visit) for visit in args.visits.split(',') if visit]
        except ValueError:
            # Report malformed input as a usage error instead of an uncaught traceback.
            parser.error(f"--visits expects a comma separated list of integers, got {args.visits!r}")

    ds_name = f"SOF_hip/{args.configuration}" if args.configuration else 'SOF_hip'
    ds = tfds.load(ds_name, split='train', data_dir=args.data_dir or None)

    # (height, width) order; either entry is None when the corresponding option was not given.
    target_size = (args.height, args.width)

    included_ids = ids_from_file(args.include)
    excluded_ids = ids_from_file(args.exclude)

    export_images(ds, args.target_path, format=args.format, downsample_to=target_size, split_lr=args.split_lr,
                  flip_lr=args.flip_lr, visits=visits,
                  included_ids=included_ids, excluded_ids=excluded_ids,
                  num_groups=args.num_groups,
                  max_group_size=args.max_group_size,
                  randomized_groups=args.randomized_groups)


# Run the exporter only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
