#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os, argparse, sys, pkg_resources
from pathlib import Path
from datetime import date

import SimpleITK as sitk

import LabelFusion
from LabelFusion.itkUtils import *
from LabelFusion.utils import *
from LabelFusion.wrapper import *

def main():
  """Command-line entry point: fuse several label images into a single one.

  Reads ``-inputs``, ``-classes``, ``-method`` and ``-output`` from
  ``sys.argv``, validates the arguments, compares every input against the
  first one with ``imageSanityCheck``, loads the inputs with SimpleITK as
  8-bit unsigned images, hands them to ``fuse_images`` and writes the result
  to the ``-output`` path.

  Whitespace around the comma-separated entries of ``-inputs`` and
  ``-classes`` is ignored, and empty entries (for example from a trailing
  comma) are dropped.

  Raises:
    SystemExit: always. The code is 0 on success; on failure it carries a
      message when fewer than two inputs are given, when ``-classes`` is not
      a list of integers, or when ``imageSanityCheck`` rejects an input.
      argparse exits on its own for malformed command lines.
  """
  copyrightMessage = (
    'Contact: software@cbica.upenn.edu\n\n'
    'This program is NOT FDA/CE approved and NOT intended for clinical use.\n'
    'Copyright (c) ' + str(date.today().year) + ' University of Pennsylvania. All rights reserved.'
  )
  parser = argparse.ArgumentParser(
    prog='LabelFusion',
    formatter_class=argparse.RawTextHelpFormatter,
    description="Fusion of different labels together.\n\n" + copyrightMessage,
  )
  parser.add_argument('-inputs', type=str, help='The absolute, comma-separated paths of labels that need to be fused', required=True)
  parser.add_argument('-classes', type=str, help="The expected labels; for example, for BraTS, this should be '0,1,2,4'; this is not needed for STAPLE|ITKVoting", default='0,1')
  parser.add_argument('-method', type=str, help='The method to apply; currently available: STAPLE | ITKVoting | MajorityVoting | SIMPLE', required=True)
  parser.add_argument('-output', type=str, help='The output file to write the results', required=True)

  args = parser.parse_args()

  # Input file paths; stray whitespace and empty entries (e.g. a trailing
  # comma) must not count as labels.
  inputs = [path.strip() for path in args.inputs.split(',') if path.strip()]

  if len(inputs) < 2:
    sys.exit('Cannot perform fusion for less than 2 input labels')

  # Validate the cheap textual arguments before touching any image so that a
  # typo is reported immediately and as a readable message, not a traceback.
  invalid_classes_message = "The expected classes must be comma-separated integers, for example '0,1,2,4'"
  try:
    class_list_int = [int(label) for label in args.classes.split(',') if label.strip()]
  except ValueError:
    sys.exit(invalid_classes_message)
  if not class_list_int:
    sys.exit(invalid_classes_message)

  method = args.method.lower()

  # Every input is compared against the first one.
  reference = inputs[0]
  for candidate in inputs[1:]:
    if not imageSanityCheck(reference, candidate):
      sys.exit('There is a mismatch between the physical definitions of the input labels, please check')

  inputSegmentationImages = [sitk.ReadImage(path, sitk.sitkUInt8) for path in inputs]

  fused_segmentation_image = fuse_images(inputSegmentationImages, method, class_list_int)

  # finally, write the fused image
  sitk.WriteImage(fused_segmentation_image, args.output)
  # sys.exit instead of the site-provided ``exit`` builtin, which is missing
  # when the interpreter runs without the ``site`` module.
  sys.exit(0)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if "__main__" == __name__:
  main()
