#!/scratch/boydb1/apps/miniconda2/bin/python
# -*- coding: utf-8 -*-

'''
Upload demographic data to a project on XNAT.

Created on May 6, 2013
Edited on February 19, 2015
Edited on February 20, 2017

@author: Benjamin Yvernault, Electrical Engineering, Vanderbilt University
'''

import csv
import os

from dax import XnatUtils
from dax.errors import XnatToolsUserError
import dax.xnat_tools_utils as utils


__copyright__ = 'Copyright 2013 Vanderbilt University. All Rights Reserved'
__exe__ = os.path.basename(__file__)
__author__ = 'byvernault'
__purpose__ = 'Upload demographic data to XNAT from a csv file'
# Help text shown by the command line tool. The examples must stay in sync
# with the options declared in add_to_parser (-c/--csv, --report,
# --delimiter, --format, --sessformat).
__description__ = """What is the script doing :
   * Upload demographic data to a project on XNAT from a CSV file giving to \
the script as an input.
     Demographic data can only be upload on a SUBJECT or SESSION with this \
tool.
     Default parameters that can be uploaded to XNAT: handedness/gender/yob\
/age/scanner/scanner_manufacturer/scanner_model/acquisition_site.
     You can also upload your specific parameters but you need to create \
them first on XNAT.
     By default, specific parameters are supposed to be part of the subject \
level.

Examples:
   *Upload demographic data handedness/gender/age:
        XnatDemographic -c demographic_data.csv
   *See the report before uploading:
        XnatDemographic -c demographic_data.csv --report
   *Upload demographic data with a file delimited by a /:
        XnatDemographic -c demographic_data.csv --delimiter="/"
   *Upload demographic data with a different header that the default one:
        XnatDemographic -c demographic_data.csv --format=project_id,\
subject_label,session_label,race,handedness,gender,age
   *Upload demographic data with a different header (and session variables):
    XnatDemographic -c demographic_data.csv --format=project_id,subject_label,\
session_label,gender,MCR,state \
--sessformat MCR,state
"""

# Columns identifying where a row goes; they are never uploaded as values.
REQUIRED_VARIABLES_LIST = ['project_id', 'subject_label', 'session_label']
# Default demographic attributes stored at the subject level.
SUBJECT_PARAMETERS_LIST = ['handedness', 'gender', 'yob']
# Default attributes stored at the session level.
SESSION_PARAMETERS_LIST = ['age', 'scanner', 'scanner_manufacturer',
                           'scanner_model', 'acquisition_site']


def get_yob_from_label(yob, row_index):
    """
    Normalise a year of birth read from the csv.

    A value containing "/" is treated as a date and only the part after the
    last "/" is kept (a warning is printed). The remaining value is returned
    unchanged when int() accepts it; otherwise a warning is printed and an
    empty string is returned.

    :param yob: year of birth as read from the csv
    :param row_index: zero-based row index, reported one-based in warnings
    :return: the (possibly shortened) value if it parses as an integer,
     '' otherwise
    """
    if '/' in yob:
        last_part = yob.split('/')[-1]
        date_msg = (' warning: row %s -- "/" detected in the year of birth '
                    '(yob) value. Using last digit of the string: %s')
        print(date_msg % (str(row_index + 1), last_part))
        yob = last_part

    try:
        int(yob)
    except Exception:
        bad_msg = ' warning: row %s -- year of birth (yob) not a proper integer.'
        print(bad_msg % (str(row_index + 1)))
        return ''
    return yob


def read_csv(csv_file, session_vars, header_format, delimiter):
    """
    Read the demographics csv and convert every data row into a dictionary.

    The header comes from header_format when it is given (the first line of
    the file is then treated as data); otherwise the first line of the file
    is used as the header.

    :param csv_file: path to the csv file containing the demographics
    :param session_vars: comma separated names of extra session variables
    :param header_format: comma separated header overriding the csv one
    :param delimiter: delimiter used in the csv
    :return: (list of non-empty row dictionaries, header), or (None, None)
     when the header lacks "project_id" or "subject_label"
    """
    # Variables stored at the session level: the defaults plus user ones.
    if session_vars:
        session_variables = set(session_vars.split(',') +
                                SESSION_PARAMETERS_LIST)
    else:
        session_variables = SESSION_PARAMETERS_LIST
    print('Warning: the variables for a session are the following : %s'
          % (','.join(session_variables)))

    header = list()
    if header_format:
        print('Warning: the header has been set by the options --format.')
        header = header_format.split(',')
    else:
        print('Warning: the header has not been set. Using the first line '
              'for the header.')
    print('INFO: Reading CSV -- if you see any warning, it means that the '
          'information on the row will not be uploaded.')

    parsed_rows = list()
    with open(csv_file, 'rt') as handle:
        reader = csv.reader(handle, delimiter=delimiter)
        for line_no, fields in enumerate(reader):
            if line_no > 0:
                parsed_rows.append(
                    get_row_csv(fields, header, line_no, session_variables))
                continue

            # First line: either the header itself or the first data row.
            if header:
                parsed_rows.append(
                    get_row_csv(fields, header, line_no, session_variables))
            else:
                header = fields

            # Validate the header once it is known.
            if 'project_id' not in header or 'subject_label' not in header:
                print(' warning: row %s -- "project_id" and "subject_label" '
                      'not in the header. Check the header.'
                      % (str(line_no + 1)))
                return None, None
            if 'age' in header and 'session_label' not in header:
                print(' warning: row %s -- age found in header but '
                      'session_label not in it. Check the header.'
                      % (str(line_no + 1)))

    # Rows rejected by get_row_csv come back as empty dictionaries.
    return [entry for entry in parsed_rows if entry], header


def get_row_csv(row, header, row_index, session_variables):
    """
    Build the upload dictionary for one csv row.

    Required columns (project/subject/session labels) are stored at the top
    level. Every other non-empty value goes into 'upload_demo_session' when
    its column is a session variable, else into 'upload_demo_subject'
    (handedness, gender and yob are converted first).

    :param row: list of values of the csv row
    :param header: list of column names matching the row
    :param row_index: zero-based row index, reported one-based in warnings
    :param session_variables: names of the session level variables
    :return: dictionary for the row, empty if the row does not have as many
     columns as the header
    """
    if len(row) != len(header):
        print(' WARNING: row %s -- less or more columns than the header. '
              'Check the csv file.' % (str(row_index + 1)))
        return dict()

    subject_values = dict()
    session_values = dict()
    record = dict()
    record['upload_demo_subject'] = subject_values
    record['upload_demo_session'] = session_values

    for name, cell in zip(header, row):
        if name in REQUIRED_VARIABLES_LIST:
            record[name] = cell
            continue
        if not cell:
            # Empty cells are never uploaded.
            continue
        if name in session_variables:
            session_values[name] = cell
        elif name == 'handedness':
            subject_values[name] = utils.get_handedness_from_label(cell)
        elif name == 'gender':
            subject_values[name] = utils.get_gender_from_label(cell)
        elif name == 'yob':
            subject_values[name] = get_yob_from_label(cell, row_index)
        else:
            subject_values[name] = cell
    return record


def print_report(demographics, header):
    """
    Method to print the report on the upload of demographic data

    One comma separated line is printed per row, with the values in the
    order of the header. Rows are ordered by project, subject and session
    labels.

    :param demographics: list of demographic to upload (dicts)
    :param header: list of field names for header
    :return: None
    """
    # Display:
    print("""Report information :
List of the data that need to be upload for the header:
%s""" % ','.join(header))
    utils.print_separators(symbol='-')

    def _labels(obj_dict):
        # Sort on the identifying labels rather than on the dictionaries
        # themselves: dicts cannot be ordered on Python 3.
        return (obj_dict['project_id'],
                obj_dict.get('subject_label', ''),
                obj_dict.get('session_label', ''))

    for obj_dict in sorted(demographics, key=_labels):
        if obj_dict:
            # get_row_info needs the header to order the values.
            row = get_row_info(obj_dict, header)
            print(','.join(row))
    utils.print_separators(symbol='-')


def get_row_info(obj_dict, header):
    """
    Collect the values of a row dictionary in the order of the header.

    Each field is looked up at the top level of the dictionary first, then
    in 'upload_demo_subject', then in 'upload_demo_session'. A field found
    nowhere yields an empty string.

    :param obj_dict: dictionary describing one row
    :param header: list of field names for header
    :return: list of values, one per header field
    """
    values = list()
    for name in header:
        if name in obj_dict:
            values.append(obj_dict[name])
            continue
        # Subject level values take precedence over session level ones.
        for section in ('upload_demo_subject', 'upload_demo_session'):
            nested = obj_dict[section]
            if name in nested:
                values.append(nested[name])
                break
        else:
            values.append('')
    return values


def print_info_row(object_dict):
    """
    Print a one line summary of what is uploaded for a row.

    The line shows the project and subject, the session when the row has
    one, then every tag/value pair of the subject and session sections.

    :param object_dict: dictionary describing one row
    :return: None
    """
    pieces = [' *Project: %s  / Subject: %s'
              % (object_dict['project_id'], object_dict['subject_label'])]
    if 'session_label' in object_dict:
        pieces.append(' / Session: %s' % (object_dict['session_label'],))

    # Subject values are listed before session values.
    for section in ('upload_demo_subject', 'upload_demo_session'):
        if section not in object_dict:
            continue
        for tag, value in object_dict[section].items():
            pieces.append(' / %s: %s' % (tag, value))

    print(''.join(pieces))


def print_format():
    """
    Print the variable names usable in the csv header.

    Lists the required variables, then the default subject attributes and
    the default session attributes, followed by usage notes.

    :return: None
    """
    bullet = ' * %*s '
    intro = ('INFO: Printing the variables available for the csv as subject/ '
             'session level (required as well in the csv to know where to set '
             'the information specified by the user):\n'
             ' Required variables:')
    sections = (
        (intro, REQUIRED_VARIABLES_LIST),
        ('Default Subject demographic attributes:', SUBJECT_PARAMETERS_LIST),
        ('Default Session demographic attributes:', SESSION_PARAMETERS_LIST),
    )
    for title, names in sections:
        print(title)
        for name in names:
            # Negative width: left-justify the name in a 30 character field.
            print(bullet % (-30, name))

    notes = """INFO: You can generate your own attributes by creating them on \
the XNAT directly and use the script to upload the values.

WARNING:
    * you need to have "project_id,subject_label" if you upload demographic \
data to a subject.
    * for session attributes, you need to give the session_label as well.
    * for specific attributes that you created on XNAT, please put the name \
of the variables in the header."""
    print(notes)


def upload_demographic_data(xnat, demographics):
    """
    Upload every row of demographic data to XNAT, ordered by project.

    Rows whose subject does not exist on XNAT are reported and skipped.

    :param xnat: pyxnat.interface object
    :param demographics: list of demographic to upload (dicts)
    :return: None
    """
    ordered = sorted(demographics, key=lambda entry: entry['project_id'])
    for entry in ordered:
        subject = xnat.select_subject(project=entry['project_id'],
                                      subject=entry['subject_label'])
        if subject.exists():
            print_info_row(entry)
            update_demo_data(subject, entry)
            continue

        print(" --> WARNING: Subject %s doesn't exist. No information "
              "will be uploaded." % (entry['subject_label']))


def update_demo_data(subject_obj, obj_dict):
    """
    Push the subject and session values of one row to XNAT.

    Subject values are sent when the row has an 'upload_demo_subject'
    section. Session values are sent only when the row has both an
    'upload_demo_session' section and a 'session_label'.

    :param subject_obj: pyxnat.Interface subject Object
    :param obj_dict: dictionary describing one row
    :return: None
    """
    if 'upload_demo_subject' in obj_dict:
        update_tags_subject(subject_obj, obj_dict['subject_label'],
                            obj_dict['upload_demo_subject'])

    has_session_data = ('upload_demo_session' in obj_dict and
                        'session_label' in obj_dict)
    if not has_session_data:
        return

    label = obj_dict['session_label']
    update_tags_session(subject_obj.experiment(label), label,
                        obj_dict['upload_demo_session'])


def update_tags_subject(subject_obj, subject_label, subj_dict):
    """
    Method to set tags for a subject

    Nothing is sent when subj_dict is empty. A single tag is set only when
    the subject has no value for it yet. Several tags are sent in one mset
    call, using the demographics xpath for the default attributes and the
    custom field xpath for the others.

    :param subject_obj: pyxnat.Interface subject Object
    :param subject_label: subject label on XNAT
    :param subj_dict: dictionary of tag and value to set for the subject
    :return: None
    """
    if not subj_dict:
        # Rows carrying only session values have an empty subject section:
        # do not send an empty request nor report that something was set.
        return

    if len(subj_dict) == 1:
        tag, value = list(subj_dict.items())[0]
        if subject_obj.attrs.get(tag):
            print('already set:'+tag)
        else:
            print('setting {} to {}'.format(tag, value))
            subject_obj.attrs.set(tag, value)
        return

    # The xpath must not contain any whitespace, so it is assembled from
    # adjacent literals instead of a backslash continuation, which would
    # pull the indentation of the next line into the string.
    default_key = ('xnat:subjectData/demographics'
                   '[@xsi:type=xnat:demographicData]/%s')
    custom_key = 'xnat:subjectData/fields/field[name=%s]/field'

    mset_dict = dict()
    for tag, value in subj_dict.items():
        if tag in SUBJECT_PARAMETERS_LIST:  # default attributes
            key_tmp = default_key
        else:
            key_tmp = custom_key
        mset_dict[key_tmp % tag.lower()] = value

    subject_obj.attrs.mset(mset_dict)
    print("  - %s set on subject." % (', '.join(list(subj_dict.keys()))))


def update_tags_session(session_obj, session_label, sess_dict):
    """
    Method to set tags for a specified session

    Nothing is sent when sess_dict is empty. A warning is printed when the
    session does not exist. A single tag is set only when the session has no
    value for it yet; several tags are sent in one mset call. Both paths
    build the attribute key the same way: "<xsiType>/<tag>" for the default
    session attributes, "<xsiType>/fields/field[name=<tag>]/field" for the
    others.

    :param session_obj: pyxnat.Interface session Object
    :param session_label: session label on XNAT
    :param sess_dict: dictionary of tag and value to set for the session
    :return: None
    """
    if not sess_dict:
        # Nothing to upload: avoid querying XNAT and sending an empty mset.
        return

    if not session_obj.exists():
        msg = "  --> warning: Session %s doesn't exist or not set."
        print(msg % (session_label))
        return

    xsitype_sess = session_obj.attrs.get('xsiType')
    mset_dict = dict()
    for tag, value in sess_dict.items():
        if tag in SESSION_PARAMETERS_LIST:
            key_tmp = '%s/%s'
        else:
            key_tmp = '%s/fields/field[name=%s]/field'
        mset_dict[key_tmp % (xsitype_sess, tag.lower())] = value

    if len(sess_dict) == 1:
        # Single value: keep what is already on XNAT instead of overwriting.
        tag, value = list(mset_dict.items())[0]
        if session_obj.attrs.get(tag):
            print('already set:'+tag)
        else:
            print('setting {} to {}'.format(tag, value))
            session_obj.attrs.set(tag, value)
        return

    session_obj.attrs.mset(mset_dict)
    print("  - %s set on session." % (', '.join(list(sess_dict.keys()))))


def run_xnat_demographics(args):
    """
    Main function for xnat demographics.

    Prints the available header variables when requested, then reads the
    csv file and either prints the report or uploads the data to XNAT.

    :param args: arguments parse by argparse
    :raises XnatToolsUserError: if no csv file is given (unless only
     --printformat was requested) or if the csv file does not exist
    """
    utils.print_separators()
    if args.printformat:
        print_format()
        utils.print_separators()

    if not args.csvfile:
        # Without a csv there is nothing to read. This is only acceptable
        # when the user just asked for the list of variables.
        if not args.printformat:
            err = "no csv file given. Use -c/--csv to set the file to read."
            raise XnatToolsUserError(__exe__, err)
        utils.print_end(__exe__)
        return

    if not os.path.exists(args.csvfile):
        err = "the csv file '%s' does not exist."
        raise XnatToolsUserError(__exe__, err % args.csvfile)

    demographics, header = read_csv(args.csvfile, args.session_format,
                                    args.format, args.delimiter)
    if not demographics:
        print('INFO: no valid demographic data found.')
    elif args.report:
        print_report(demographics, header)
    else:
        # The command line host takes precedence over the environment.
        if args.host:
            host = args.host
        else:
            host = os.environ['XNAT_HOST']
        user = args.username
        with XnatUtils.get_interface(host=host, user=user) as xnat:
            print('INFO: connection to xnat <%s>:' % (host))
            msg = 'INFO: Uploading demographics data to XNAT <%s>. It \
will take some time to upload demographic data. Please be patient.'
            print(msg % (host))
            upload_demographic_data(xnat, demographics)

    utils.print_end(__exe__)


def add_to_parser(parser):
    """
    Method to add arguments to default parser for xnat_tools in utils.

    Adds -c/--csv, -d/--delimiter, --format, --sessformat, --report and
    --printformat.

    :param parser: parser object
    :return: parser object with new arguments
    """
    _h = "File path as inputs that will be read for XNAT information. \
Default header used: project_id,subject_label,session_label,handedness,\
gender,age."
    parser.add_argument("-c", "--csv", dest="csvfile", default=None, help=_h)
    # The delimiter is used to read the input csv, not to write any output.
    _h = "Delimiter used in the input csv file. Default: comma."
    parser.add_argument("-d", "--delimiter", dest="delimiter", default=',',
                        help=_h)
    _h = "Header for the csv. format: list of variables name separated by a \
comma. Default: using first line in the csv for the header."
    parser.add_argument("--format", dest="format", default=None, help=_h)
    _h = "List of variables for the session from the header. format: \
comma separated list."
    parser.add_argument("--sessformat", dest="session_format", default=None,
                        help=_h)
    _h = "Show what information the script will upload to XNAT."
    parser.add_argument("--report", dest="report", action="store_true",
                        default=False, help=_h)
    parser.add_argument("--printformat", dest="printformat",
                        action='store_true',
                        help="Print available parameters for the csv header.")
    return parser


# Script entry point: hand the tool name, help text, parser hook, purpose and
# runner function over to utils.run_tool.
if __name__ == '__main__':
    utils.run_tool(__exe__, __description__, add_to_parser, __purpose__,
                   run_xnat_demographics)
