#!/usr/bin/env python
'''
Create/Upload BIDS datatype/tasktype mapping.

Created on October, 2019

@author: Praitayini Kanakaraj, Electrical Engineering, Vanderbilt University
'''
import os
import re
import csv
import sys
import json
import logging
import argparse
from dax import XnatUtils
from datetime import datetime

# Help text printed by argparse. The parser uses RawTextHelpFormatter, so the
# layout below is shown to the user verbatim.
# CSV headers must match --xnatinfo and --type exactly (no space after the
# comma), which is why the examples are written without padding.
DESCRIPTION = """What is the script doing :
   *Uploads BIDS datatype, tasktype, asltype, repetition time and run number mapping to XNAT project level \
using the different OPTIONS.

Examples:
   *Create a new datatype mapping for scan_type of XNAT scans:
        BIDSMapping -p PID --xnatinfo scan_type --type datatype --create /tmp/projectID_datatype.csv
   *The correct format for /tmp/projectID_datatype.csv
        scan_type,datatype
        Resting State,func
   *Create a new datatype mapping for series_description of XNAT scans:
        BIDSMapping -p PID --xnatinfo series_description --type datatype --create /tmp/projectID_datatype.csv
   *Create a new tasktype mapping for scan_type of XNAT scans:
        BIDSMapping -p PID --xnatinfo scan_type --type tasktype --create /tmp/projectID_tasktype.csv
   *Replace tasktype mapping for scan_type of XNAT scans: (It removes the old mapping and uploads the new mapping)
        BIDSMapping -p PID --xnatinfo scan_type --type tasktype --replace /tmp/projectID_tasktype.csv
   *Update tasktype mapping for scan_type of XNAT scans: (This can ONLY add new mapping rules, it CANNOT remove \
rules; use --replace to remove and add mapping rules)
        BIDSMapping -p PID --xnatinfo scan_type --type tasktype --update /tmp/projectID_tasktype.csv
   *Create default datatype mapping for scan_type of XNAT scans: (There is no default for series_description, \
use --create)
        BIDSMapping -p PID --xnatinfo scan_type --type datatype --create_default
   *Download the current mapping on XNAT into a folder:
        BIDSMapping -p PID --xnatinfo scan_type --type datatype --download /tmp/download
   *Download the scan_types on project on XNAT:
        BIDSMapping -p PID --xnatinfo scan_type --template /tmp/scan_types.csv
* csv file example
    * BIDS datatype
        scan_type,datatype
        T1W/3D/TFE,anat
        Resting State,func
    * BIDS tasktype
        scan_type,tasktype
        Resting State,rest
    * BIDS repetition_time_sec
        scan_type,repetition_time_sec
        Resting State,2.0
    * BIDS asltype
        scan_type,asltype
        casl_m0,m0scan
        casl,asl
    * BIDS run_number
        scan_type,run_number
        Resting State,01
"""

def add_parser_args():
    """
    Method to declare all the arguments for the tool on ArgumentParser.

    argparse itself treats every option as optional; --project, --xnatinfo
    and --type are checked afterwards by the main script.

    :return: parser object (the arguments are not parsed here)
    """
    argp = argparse.ArgumentParser(description=DESCRIPTION, usage='use "%(prog)s --help" for more information',
                                   formatter_class=argparse.RawTextHelpFormatter)
    # Connect to XNAT
    argp.add_argument('--host', dest='host', default=None,
                      help='Host for XNAT. Default: using $XNAT_HOST.')
    argp.add_argument('-u', '--username', dest='username', default=None,
                      help='Username for XNAT. Default: using $XNAT_USER.')
    argp.add_argument("-o", "--logfile", dest="logfile", default=None,
                      help="Write the display/output in a file given to this OPTIONS.")
    # Specify Project [Required]
    argp.add_argument("-p", '--project', dest="project", default=None,
                      help='Project to create/update BIDS mapping file')
    # Specify Type of BIDS Mapping [Required]
    # The help lists every value accepted by the main script, including run_number.
    argp.add_argument('-t', '--type', dest="type", default=None,
                      help='The type of mapping either datatype, tasktype, repetition_time_sec, run_number, '
                           'or asltype')
    # Specify XNAT type for mapping [Required]
    argp.add_argument('-x', '--xnatinfo', dest="xnatinfo", default=None,
                      help='The type of xnat info to use for mapping either scan_type or series_description')
    # Create mapping rules at Project level. The option always takes a csv
    # file; the default mapping is requested with --create_default instead.
    argp.add_argument('-c', '--create', dest="create", default=None,
                      help='Create the given BIDS new mapping file at project level. (EG. --create <mappingfile>.csv)\n'
                           'Use --create_default to create the default mapping instead.\n'
                           'csvfile EG:\n'
                           'scan_type,datatype\n'
                           'T1W/3D/TFE,anat\n'
                           'Resting State,func\n')
    # Create default mapping rules at Project level (default mapping is regex on LANDMAN project)
    argp.add_argument('--cd', '--create_default', dest="create_default", action='store_true',
                      help='Default create creates the default mapping at project file. (EG. --create_default)')
    # Update mapping rules (it can only add new rules) at Project level
    argp.add_argument('-a', '--update', dest="update", default=None,
                      help='Update the existing BIDS mapping file at project level. (EG. --update <mappingfile>.csv)\n'
                           'This option can only add rules')
    # Replace mapping rules (it will remove previous rules and add new rules) at Project level
    argp.add_argument('-r', '--replace', dest="replace", default=None,
                      help='Replace the existing BIDS mapping file at project level. (EG. --replace <mappingfile>.csv)\n'
                           'This option can remove and add new rules')
    # Change the mapping rules to a specific time from the LOGFILE at the Project level
    argp.add_argument('-v', '--revert', dest="revert", default=None,
                      help='Revert to an old mapping from a specific date/time. (EG: --revert 10-17-19-21:32:15\n'
                           'or --revert 10-17-19). Check the LOGFILE at project level for the date')
    # Download current mapping rules at the Project level
    argp.add_argument('-d', '--download', dest="download", default=None,
                      help='Downloads the current BIDS mapping file (EG: --download <foldername>)')
    # Write the list of scan_type/series_description values of the Project to a csv file
    argp.add_argument('--template', dest="template", default=None,
                      help='Default mapping template (EG: --template <template file>)')

    return argp

def main_display():
    """
    Print the banner of the tool, then one row for each option that is set.

    The values are read from the module-level OPTIONS namespace; options
    that are empty are not listed.
    """
    for header_line in (
            '################################################################',
            '#                         BIDSMAPPING                          #',
            '#                                                              #',
            '# Developed by the MASI Lab Vanderbilt University, TN, USA.    #',
            '# If issues, please start a thread here:                       #',
            '# https://groups.google.com/forum/#!forum/vuiis-cci            #',
            '# Usage:                                                       #',
            '#     Upload rules/mapping to Project level on XNAT            #',
            '# Parameters :                                                 #'):
        print(header_line)

    # Negative widths passed through %*s left-justify the label and the value.
    row_fmt = '#     %*s -> %*s#'
    # (OPTIONS attribute, label, keep-end flag for get_proper_str).
    # A flag of None marks the boolean switch, which is displayed as 'on'.
    rows = (
        ('host', 'XNAT Host', False),
        ('username', 'XNAT User', False),
        ('project', 'Project ID', False),
        ('xnatinfo', 'XNAT mapping type', True),
        ('type', 'BIDS mapping type', True),
        ('create_default', 'Create default mode', None),
        ('create', 'Create mapping with', True),
        ('update', 'Update mapping with', True),
        ('replace', 'Replace mapping with', True),
        ('revert', 'Revert mapping to', True),
        ('download', 'Download current mapping to', True),
    )
    for attr, label, keep_end in rows:
        value = getattr(OPTIONS, attr)
        if not value:
            continue
        if keep_end is None:
            shown = 'on'
        else:
            shown = get_proper_str(value, keep_end)
        print(row_fmt % (-20, label, -33, shown))
    print('################################################################')

def get_proper_str(str_option, end=False):
    """
    Shorten a string so that it fits in the display column.

    Strings of at most 32 characters are returned untouched. Longer ones
    are cut to 29 characters plus a '...' marker.

    :param str_option: string to shorten
    :param end: if True keep the last characters, otherwise keep the first
    :return: the string, shortened when needed
    """
    if len(str_option) <= 32:
        return str_option
    if end:
        return '...' + str_option[-29:]
    return str_option[:29] + '...'

def setup_info_logger(name, logfile):
    """
    Prepare the logger used for the output of the executable.

    A handler is attached on every call: a file handler opened in write
    mode when a log file is given, a stream handler otherwise.

    :param name: Name of the logger
    :param logfile: log file path to write outputs (falsy for no file)
    :return: logging object set at INFO level
    """
    handler = logging.FileHandler(logfile, 'w') if logfile else logging.StreamHandler()

    info_logger = logging.getLogger(name)
    info_logger.setLevel(logging.INFO)
    info_logger.addHandler(handler)
    return info_logger

def template():
    """
    Method to write every distinct scan type (or series description) of the
    project in a csv file, one value per row.

    Reads the module-level OPTIONS.template (output path), XNAT, project and
    xnat_type. Empty values are skipped and the rows are sorted so that two
    runs on the same project give the same file.
    """
    #TODO sd | default mapping | rule at project level
    template_path = OPTIONS.template
    # Get all the unique, non-empty values in the project
    scans_list_global = XNAT.get_project_scans(project)
    uniq_sd = sorted({sd[xnat_type] for sd in scans_list_global if sd[xnat_type]})
    # Write the list to the file given by the user.
    # newline='' lets the csv module control the line endings itself.
    with open(template_path, "w", newline='') as f:
        wr = csv.writer(f)
        for value in uniq_sd:
            wr.writerow([value])
    LOGGER.info('Template with %s is written to %s' % (xnat_type, template_path))

def default():
    """
    Function to create the default mapping at project level if
    previous mapping does not exist.

    The default mapping is built by matching regular expressions against the
    scan_type/series_description (xnat_type) of every scan of the project.
    The mapping is keyed on that same xnat_type value. Rules are applied in
    order, so a later rule overrides an earlier one for the same key.
    #TODO go back to default when wanted. Use findchange for logging
    """
    # Repetition time, run number and asl type should be given by the user,
    # they have no default mapping
    if mapping_type in ('repetition_time_sec', 'run_number', 'asltype'):
        LOGGER.info('%s does not have default mapping, use --create' % mapping_type)
        return

    # (pattern, regex flags, BIDS value) in the order they are applied
    if mapping_type == 'datatype':
        rules = (
            ('T1|T2|T1W', 0, "anat"),
            ('rest|Resting state', re.IGNORECASE, "func"),
            ('dwi|dti', re.IGNORECASE, "dwi"),
            ('Field', re.IGNORECASE, "fmap"),
            ('asl|ASL|CASL|pCASL|pASL', re.IGNORECASE, "perf"),
        )
    elif mapping_type == 'tasktype':
        rules = (('rest', re.IGNORECASE, "rest"),)
    else:
        rules = ()

    new_mapping = {project: {}}
    if rules:
        scans_list_global = XNAT.get_project_scans(project)
        for pattern, flags, bids_value in rules:
            for sd in scans_list_global:
                xnat_value = sd[xnat_type]
                # Scans without a value cannot be matched nor used as a key
                if not xnat_value:
                    continue
                if re.search(pattern, xnat_value, flags=flags) is not None:
                    new_mapping[project][xnat_value] = bids_value

    new_json_name = d_m_y + '_' + mapping_type + '.json'
    # If mapping is not present upload the mapping created (above) with regex, if not report it
    if check_mapping_exist():
        LOGGER.info('Mapping exists at project level. You can update it using --update to update it '
                    'or --replace to replace the mapping')
        return
    XNAT.select(os.path.join(res_files, new_json_name)).put(src=json.dumps(new_mapping, indent=2))
    log_new_mapping(new_mapping)
    LOGGER.info('CREATED: Default mapping file %s is uploaded' % new_json_name)

def check_mapping_exist():
    """
    Method to tell whether a json mapping is stored at Project level

    :return: True when at least one name listed under res_files ends
     with '.json', False otherwise
    """
    json_names = [name for name in XNAT.select(res_files).get() if name.endswith('.json')]
    return len(json_names) > 0

def create():
    """
    Function to read the given csv file and create new mapping
    at project level.

    The csv (OPTIONS.create) is validated with csv_check first. Each row
    maps the value of the xnat_type column to the value of the mapping_type
    column. Nothing is uploaded when a json mapping already exists.
    """
    # Read the given csv data into new_mapping
    dir_new_csv = OPTIONS.create
    csv_check(dir_new_csv)
    new_mapping = {project: {}}
    with open(dir_new_csv) as csv_file:
        for row in csv.DictReader(csv_file):
            # csv_check guarantees the second column is named after mapping_type
            new_mapping[project][row[xnat_type]] = row[mapping_type]
    # Log the same month-day-year stamp that is written in the LOGFILE,
    # since this is the value --revert expects
    LOGGER.info('date %s' % d_m_y)

    # If mapping is not present upload the new_mapping from the CSV, if not report it
    new_json_name = d_m_y + '_' + mapping_type + '.json'
    if check_mapping_exist():
        LOGGER.info('Mapping exists at project level. You can update it using --update to update it '
                    'or --replace to replace the mapping')
        return
    XNAT.select(os.path.join(res_files, new_json_name)).put(src=json.dumps(new_mapping, indent=2))
    log_new_mapping(new_mapping)
    LOGGER.info('CREATED: New mapping file %s is uploaded' % new_json_name)


def log_new_mapping(d):
    """
    Method to write one " + " row in the log for every rule of a mapping
    that has just been created at project level
    :param d: new mapping, with the project ID as top-level key
    """
    for xnat_value, bids_value in list(d[project].items()):
        CSVWRITER.writerow((d_m_y, " + ", xnat_value + " : " + bids_value))


def update_or_replace():
    """
    Function to read the csv file with new rules and add these new rules
    at project level. Update can only add new rules. Replace will remove the
    previous rules and add new ones.

    When both --update and --replace are given, --replace wins: its csv is
    the one read and only the replace is applied and logged.
    """
    new_json_name = d_m_y + '_' + mapping_type + '.json'
    # Get the CSV path for CSV mapping file
    replacing = bool(OPTIONS.replace)
    src_path = OPTIONS.replace if replacing else OPTIONS.update
    csv_check(src_path)
    new_mapping = {project: {}}
    with open(src_path, 'r') as csv_file:
        for row in csv.DictReader(csv_file):
            # csv_check guarantees the second column is named after mapping_type
            new_mapping[project][row[xnat_type]] = row[mapping_type]

    # Get the json mapping from XNAT (more than one file is expected:
    # the LOGFILE plus the mapping)
    resources = XNAT.select(res_files).get()
    if len(resources) <= 1:
        LOGGER.error('ERROR: No mapping at project level to update or replace')
        return

    found_json = False
    for res in resources:
        if not res.endswith('.json'):
            continue
        found_json = True
        tmp_path = os.path.join("/tmp", res)
        local_copy = XNAT.select(os.path.join(res_files, res)).get(dest=tmp_path)
        # Load and close the local copy before removing it
        with open(local_copy, "r") as f:
            prev_mapping = json.load(f)
        os.remove(tmp_path)

        if project not in prev_mapping:
            LOGGER.error('ERROR: The mapping file format should have project as key')
            continue

        if replacing:
            findDiff_change(new_mapping, prev_mapping)
            prev_mapping = new_mapping
        else:
            findDiff(new_mapping, prev_mapping)
            for top_key, rules in new_mapping.items():
                prev_mapping[top_key].update(rules)

        # Dump the updated previous mapping
        XNAT.select(os.path.join(res_files, new_json_name)).put(src=json.dumps(prev_mapping, indent=2),
                                                                overwrite=True)
        # Never delete the file that has just been uploaded
        if res != new_json_name:
            XNAT.select(os.path.join(res_files, res)).delete()
        LOGGER.info('UPDATED: uploaded mapping file %s' % new_json_name)

    if not found_json:
        LOGGER.error('ERROR: No mapping at project level to update or replace')

def csv_check(src_path):
    """
    Method to check if the csv file used by the user has the correct values in columns.

    The first header must be the current xnat_type and the second one the
    current mapping_type. For a datatype mapping every row must hold a known
    BIDS datatype. The script exits with status 1 when the file is invalid.

    :param src_path: path of the csv mapping file
    """
    xnat_columns = ("scan_type", "series_description")
    bids_columns = ("datatype", "tasktype", "repetition_time_sec", "run_number", "asltype")
    # for more data types, add name below
    known_datatypes = ("anat", "dwi", "func", "fmap", "perf")

    with open(src_path, 'r') as csv_file:
        csv_reader = csv.DictReader(csv_file)
        # fieldnames is None for an empty file and may hold a single column
        headers = csv_reader.fieldnames or []
        headers_ok = (len(headers) >= 2
                      and headers[0] in xnat_columns
                      and headers[1] in bids_columns
                      and headers[0] == xnat_type
                      and headers[1] == mapping_type)
        if not headers_ok:
            LOGGER.error("ERROR in CSV column headers. Check 1) if the xnat_info and column header match, "
                         "2) if the type and column header match, "
                         "3) column headers should be series_description or scan_type and datatype, tasktype, "
                         "repetition_time_sec, run_number, or asltype")
            sys.exit(1)

        if headers[1] == "datatype":
            for row in csv_reader:
                if row['datatype'] not in known_datatypes:
                    LOGGER.error("ERROR in CSV unknown datatype column. Datatype should be anat, dwi, fmap, "
                                 "func, or perf")
                    sys.exit(1)
        LOGGER.info("CSV mapping format is good")

def revert():
    """
    Function to go back a mapping of a previous date/time
    using the log file at project level.

    The LOGFILE is replayed backwards: every " + " row is undone by removing
    its rule and every " - " row by restoring its rule, until a row starting
    with OPTIONS.revert is met. Nothing is changed when no row of the
    LOGFILE starts with OPTIONS.revert.
    """
    new_json_name = d_m_y + '_' + mapping_type + '.json'
    stamp = OPTIONS.revert
    for res in XNAT.select(res_files).get():
        if not res.endswith('.json'):
            continue
        res_obj = XNAT.select(os.path.join(res_files, res))
        with open(res_obj.get(), "r") as f:
            raw_mapping = f.read()
        # Two independent copies: one to modify, one to log the difference against
        prev_mapping = json.loads(raw_mapping)
        orig_mapping = json.loads(raw_mapping)

        with open(XNAT.select(os.path.join(res_files, 'LOGFILE.txt')).get(), "r") as lf:
            lines = lf.readlines()
        # Without a matching row the whole log would be undone, so stop here
        if not any(line.startswith(stamp) for line in lines):
            LOGGER.error('ERROR: No entry starting with %s in LOGFILE. Nothing reverted' % stamp)
            return

        rules = prev_mapping[project]
        for line in reversed(lines):
            if line.startswith(stamp):
                break
            # The log is written with csv.writer, so it is parsed the same way
            fields = next(csv.reader([line]), [])
            # Skip the "Logfile" header and any row that is not date,action,rule
            if len(fields) < 3:
                continue
            action = fields[1].strip()
            key, _, value = fields[2].partition(':')
            key = key.strip()
            if action == '+':
                rules.pop(key, None)
            elif action == '-':
                rules[key] = value.strip()

        findDiff_change(prev_mapping, orig_mapping)
        XNAT.select(os.path.join(res_files, new_json_name)).put(src=json.dumps(prev_mapping, indent=2))
        # Never delete the file that has just been uploaded
        if res != new_json_name:
            res_obj.delete()
        LOGGER.info('REVERTED: changed mapping file %s' % new_json_name)


def findDiff(d1, d2, path=""):
    """
    Function to write in the log the changes made when adding rules.

    A key missing from the old mapping gives one " + " row. A key whose
    value changed gives a " - " row with the old value followed by a " + "
    row with the new one. Nested dicts are compared recursively.
    :param d1: dict with new mapping
    :param d2: dict with old mapping
    :param path: chain of the keys followed so far (only passed to the
     recursive calls)
    """
    for key, new_value in list(d1.items()):
        if key not in d2:
            CSVWRITER.writerow((d_m_y, " + ", key + " : " + new_value))
            continue

        if type(new_value) is dict:
            path = key if path == "" else path + "->" + key
            findDiff(new_value, d2[key], path)
            continue

        old_value = d2[key]
        if new_value != old_value:
            CSVWRITER.writerow((d_m_y, " - ", key + " : " + old_value))
            CSVWRITER.writerow((d_m_y, " + ", key + " : " + new_value))


def findDiff_change(d1, d2):
    """
    Function to record/log the changes made when replacing or reverting rules.

    For every top-level key of the new mapping, all the previous rules are
    logged as " - " rows and then all the new rules as " + " rows. The
    removals are written first (same order as findDiff) because revert()
    replays the log backwards: it must undo the additions before it
    restores the removed rules, otherwise a rule present in both mappings
    would be dropped.
    :param d1: dict with new mapping
    :param d2: dict with previous mapping
    """
    for top_key in d1:
        # A top-level key absent from the previous mapping has nothing to remove
        for k, v in d2.get(top_key, {}).items():
            CSVWRITER.writerow((d_m_y, " - ", k + " : " + v))
    for top_key in d1:
        for k, v in d1[top_key].items():
            CSVWRITER.writerow((d_m_y, " + ", k + " : " + v))


def mapping_download():
    """
    Function to download the mapping file at project level.

    The files are downloaded in the folder OPTIONS.download, taken relative
    to the current working directory. The folder is created when missing
    and reused when it already exists.
    """
    download_dir = os.path.join(os.getcwd(), OPTIONS.download)
    # os.makedirs raises on an existing folder, so only create it when needed
    if not os.path.isdir(download_dir):
        os.makedirs(download_dir)
    XNAT.select('/projects/' + project + '/resources/BIDS_' + mapping_type + '/').get(
                 dest_dir=download_dir)
    LOGGER.info('DOWNLOADED: %s mapping file at %s' % (mapping_type, download_dir))


if __name__ == '__main__':
    # getpass is used below to prompt for the password but is not imported
    # at the top of the file, so it is imported by the entry point itself.
    import getpass

    parser = add_parser_args()
    OPTIONS = parser.parse_args()
    main_display()
    if OPTIONS.host:
        HOST = OPTIONS.host
    else:
        HOST = os.environ['XNAT_HOST']
    # Only prompt for a password when a username is given on the command line
    if OPTIONS.username:
        MSG = "Please provide the password for user <%s> on xnat(%s):"
        PWD = getpass.getpass(prompt=MSG % (OPTIONS.username, HOST))
    else:
        PWD = None
    LOGGER = setup_info_logger('BidsMapping', OPTIONS.logfile)
    MSG = 'INFO: connection to xnat <%s>:' % (HOST)
    LOGGER.info(MSG)
    XNAT = XnatUtils.get_interface(host=OPTIONS.host,
                                 user=OPTIONS.username,
                                 pwd=PWD)
    # Time stamp used for the mapping file names and for the LOGFILE rows
    date = datetime.now()
    d_m_y = date.strftime("%m-%d-%y-%H:%M:%S")

    if not OPTIONS.project:
        LOGGER.info("WARNING: Project is require. Use --project")
    else:
        project = OPTIONS.project
        # Put the XNAT type info at project level in xnat_type.txt
        xnat_type_file = XNAT.select('/data/projects/' + project + '/resources/BIDS_xnat_type/files/xnat_type.txt')
        if not OPTIONS.xnatinfo:
            LOGGER.info("WARNING: Xnatinfo is required. Use --xnatinfo")
        elif OPTIONS.xnatinfo not in ["scan_type", "series_description"]:
            LOGGER.error("ERROR: Type must be scan_type or series_description")
        else:
            xnat_type = OPTIONS.xnatinfo
            LOGGER.info("The info used from XNAT is " + xnat_type)
            xnat_type_file.put(src=xnat_type, overwrite=True)
            # --template only needs the project and the xnat info, not --type
            if OPTIONS.template:
                template()
                sys.exit()
            if not OPTIONS.type:
                LOGGER.info("WARNING: Type is required. Use --type")
            elif OPTIONS.type not in ["datatype", "tasktype", "repetition_time_sec", "run_number","asltype"]:
                LOGGER.error("ERROR: Type must be datatype or tasktype or repetition_time_sec or run_number or asltype")
            else:
                mapping_type = OPTIONS.type
                CSVWRITER = None
                res_files = '/data/projects/' + project + '/resources/BIDS_' + mapping_type + '/files'
                xnat_logfile = XNAT.select(os.path.join(res_files, 'LOGFILE.txt'))
                if not xnat_logfile.exists():
                    # If LOGFILE does not exist revert cant be done
                    if OPTIONS.revert:
                        LOGGER.error('Cannot perform --revert. Create mapping with --create option first')
                        sys.exit()
                    xnat_logfile.put(src="Logfile\n")
                # Local copy of the LOGFILE: rows are appended to it by the
                # functions below and it is uploaded back at the end
                LOGFILE = xnat_logfile.get()
                csvfilewrite = open(LOGFILE, 'a')
                try:
                    CSVWRITER = csv.writer(csvfilewrite, delimiter=',')
                    if OPTIONS.create:
                        create()
                    if OPTIONS.create_default:
                        default()
                    if OPTIONS.update or OPTIONS.replace:
                        update_or_replace()
                    if OPTIONS.revert:
                        revert()
                    if OPTIONS.download:
                        mapping_download()
                finally:
                    # Close the local log even when one of the steps exits or fails;
                    # the upload below is only reached when every step returned
                    csvfilewrite.close()
                XNAT.select(os.path.join(res_files, 'LOGFILE.txt')).put(src=LOGFILE, overwrite=True)

