#!/usr/bin/env python

# Copyright 2020 KCL-BMEIS - King's College London
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from datetime import datetime, timezone
import os
import sys
import h5py
import re

try:
    import exetera
except ModuleNotFoundError:
    fixed_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    sys.path.append(fixed_path)
    import exetera

from exetera.core import importer
# from exetera.covidspecific import postprocess

def main():
    """Parse the command line and run the requested exetera command.

    Only the ``import`` sub-command is handled here. Its arguments are
    validated and then handed to ``importer.import_with_schema``.

    Raises:
        ValueError: if ``-i/--inputs`` is not a comma-separated list of
            ``name:file`` entries, if ``-n/--include`` or ``-x/--exclude`` is
            malformed, or if a table is named in both of them.
        SystemExit: with status -1 if the schema file or any input file does
            not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', action='version', version=exetera.__version__)

    subparsers = parser.add_subparsers(dest='command')

    parser_import = subparsers.add_parser('import')
    parser_import.add_argument('-te', '--territories', default=None,
                               help=('The territory/territories to filter the dataset on '
                                     '(runs on all territories if not set)'))
    parser_import.add_argument('-s', '--schema', required=True,
                               help='The location and name of the schema file')
    parser_import.add_argument('-i', '--inputs', required=True,
                               help='A comma-separated list of name:file entries to import')
    # NOTE(review): this option is not marked as required, so None can reach
    # the importer when it is omitted - confirm whether that is intended.
    parser_import.add_argument('-o', '--output_hdf5',
                               help='The location and name of the output hdf5 file. If this is an existing file, it will'
                                    ' be overwritten')
    parser_import.add_argument('-w', '--overwrite', action='store_true',
                               help="If set, overwrites an existing project rather than appending to it")
    # The default timestamp is captured once, when the parser is built.
    parser_import.add_argument('-ts', '--timestamp', default=str(datetime.now(timezone.utc)),
                               help='Override for the import datetime (the current time is selected otherwise)')
    parser_import.add_argument('-n', '--include',
                               help='filters out all fields apart from those in the list')
    parser_import.add_argument('-x', '--exclude',
                               help='filters out the fields in this list')

    # parser_postprocess = subparsers.add_parser('process')
    # parser_postprocess.add_argument('-i', '--input', required=True, help='The dataset to load')
    # parser_postprocess.add_argument('-o', '--output', required=True,
    #                                 help='The dataset to write results to. If this is an existing file, it will be'
    #                                      ' overwritten')
    # parser_postprocess.add_argument('-d', '--daily', action='store_true',
    #                                 help='If set, generate daily assessments from assessments')

    args = parser.parse_args()

    if 'dev' in exetera.__version__:
        msg = ("Warning: this is a development version of exetera ({}). "
               "Please use one of the release versions for actual work")
        print(msg.format(exetera.__version__))

    # TODO: a proper mechanism to register commands / handlers for commands
    if args.command == 'import':
        # Missing files are reported together and only abort the run once all
        # of them have been listed.
        errors = False
        if not os.path.isfile(args.schema):
            print('-s/--schema argument must be an existing file')
            errors = True

        tokens = [i.strip() for i in args.inputs.split(',')]
        if any(':' not in t for t in tokens):
            raise ValueError("'-i/--inputs': must be composed of a comma-separated list of name:file")
        # Split on the first ':' only, so that the file part may itself
        # contain colons.
        tokens = {t[0]: t[1] for t in (t.split(':', 1) for t in tokens)}
        print(tokens)
        for tk, tv in tokens.items():
            if not os.path.isfile(tv):
                # The flag is spelled '--inputs'; the message used to name a
                # non-existent '--import_data' option.
                print("-i/--inputs - {}: '{}' must be an existing file".format(tk, tv))
                errors = True

        include, exclude, include_table_names, exclude_table_names = {}, {}, set(), set()
        err_msg = "'{}': must be composed of a comma-separated list of name:(field1,field2,field3)"
        if args.include:
            is_valid, include, include_table_names = validate_and_extract_arg_include_exclude(args.include)
            if not is_valid:
                raise ValueError(err_msg.format('-n/--include'))

        if args.exclude:
            is_valid, exclude, exclude_table_names = validate_and_extract_arg_include_exclude(args.exclude)
            if not is_valid:
                raise ValueError(err_msg.format('-x/--exclude'))

        if args.include and args.exclude:
            # A table may be filtered by inclusion or by exclusion, not both.
            table_appear_twice = include_table_names & exclude_table_names
            if table_appear_twice:
                raise ValueError("'-n/--include, -x/--exclude': tables cannot be in {} and {} simultaneously. Please check the following tables: {}"
                                 .format(args.include, args.exclude, table_appear_twice))

        if errors:
            # sys.exit is always available, unlike the interactive helper
            # `exit` that is only injected by the `site` module.
            sys.exit(-1)

        importer.import_with_schema(args.timestamp, args.output_hdf5, args.schema, tokens, args.overwrite, include, exclude)
        
# elif args.command == 'process':
#     timestamp = str(datetime.now(timezone.utc))
#
#     flags = set()
#     if args.daily is True:
#         flags.add('daily')
#
#     with h5py.File(args.input, 'r') as ds:
#         with h5py.File(args.output, 'w') as ts:
#             postprocess.postprocess(ds, ts, timestamp, flags)

def validate_and_extract_arg_include_exclude(fields):
    """Validate and parse an ``-n/--include`` or ``-x/--exclude`` argument.

    The expected format is a comma-separated list of
    ``name:(field1,field2,field3)`` groups.

    Args:
        fields: the raw argument string.

    Returns:
        A tuple ``(is_valid, result, table_names)``. ``result`` maps each
        table name to its list of stripped field names and ``table_names`` is
        the set of table names. When the input is malformed the tuple is
        ``(False, {}, set())``.
    """
    is_valid = False
    result = {}
    # Initialised up front so that the failure path can return it too;
    # previously an invalid argument raised UnboundLocalError here.
    table_names = set()

    group_count = fields.count(':')
    # Each group contributes exactly one ':', one '(' and one ')'. An argument
    # without any group at all is malformed, not an empty filter.
    if group_count > 0 and group_count == fields.count('(') == fields.count(')'):
        # The closing parenthesis is mandatory: making it optional would let
        # unbalanced input such as "a:(b)) c:(d" slip through the count check.
        groups = re.findall(r"(\w+)\s*\:\s*\(([\w+\s,*]+)\)", fields)
        if len(groups) == group_count:
            is_valid = True
            result = {table_name: [field.strip() for field in table_fields.split(',')]
                      for table_name, table_fields in groups}
            table_names = {table_name for table_name, _ in groups}
    return is_valid, result, table_names

    
# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()