#!/usr/bin/env python3

import argparse
import configparser
import json
import os
import sys

from lisa import LISA, Log, _config, __version__

#____COMMAND LINE INTERFACE________

# Parsed-argument attributes forwarded as keyword arguments to the LISA constructor.
INSTANTIATION_KWARGS = [
    'cores',
    'isd_method',
    'verbose',
    'oneshot',
]
# Parsed-argument attributes forwarded as keyword arguments to LISA.predict.
PREDICTION_KWARGS = [
    'background_list',
    'num_background_genes',
    'background_strategy',
    'seed',
]

def extract_kwargs(args, keywords):
    """Select the attributes named in *keywords* from an argparse namespace.

    Returns a new dict mapping each keyword to its value on *args*, in the
    order given. Raises KeyError if a keyword is not an attribute of *args*.
    """
    namespace = vars(args)
    selected = {}
    for name in keywords:
        selected[name] = namespace[name]
    return selected

def is_valid_prefix(prefix):
    """argparse ``type=`` validator for output file prefixes.

    Accepts *prefix* when it names an existing directory or file, or when the
    directory the output files would be written into exists. A bare name such
    as ``results`` refers to the current working directory.

    Returns the prefix unchanged. Raises argparse.ArgumentTypeError (which
    argparse reports as a usage error) for an empty prefix or one whose parent
    directory does not exist.
    """
    if prefix:
        # os.path.dirname('results') == '' and os.path.isdir('') is False, so an
        # empty parent must be read as "the current directory".
        parent_dir = os.path.dirname(prefix) or os.curdir
        if os.path.isdir(prefix) or os.path.isfile(prefix) or os.path.isdir(parent_dir):
            return prefix

    raise argparse.ArgumentTypeError('{}: Invalid file prefix.'.format(prefix))

def lisa_download(args):
    """Handle the ``download`` subcommand by having LISA download its data for ``args.species``."""
    LISA(args.species)._download_data()


def lisa_oneshot(args):
    """Handle the ``oneshot`` subcommand: run LISA on a single gene list.

    Results are written as TSV to ``args.output_prefix + '.lisa.tsv'``. If no
    prefix was given, they are printed to stdout instead.

    With ``--save_metadata`` the prediction metadata is also written as JSON.
    It goes to ``args.output_prefix + '.metadata.json'`` when a prefix was
    given, otherwise to ``<query file basename>.metadata.json`` in the
    current directory.
    """
    # argparse delivers --background_list as an open file (FileType). Pass its
    # lines instead, the same way the query list and the one-vs-rest
    # backgrounds are handed to predict().
    prediction_kwargs = extract_kwargs(args, PREDICTION_KWARGS)
    if prediction_kwargs['background_list'] is not None:
        prediction_kwargs['background_list'] = prediction_kwargs['background_list'].readlines()

    lisa = LISA(args.species, **extract_kwargs(args, INSTANTIATION_KWARGS))
    results, metadata = lisa.predict(args.query_list.readlines(), **prediction_kwargs)

    if args.save_metadata:
        if args.output_prefix is not None:
            metadata_filename = args.output_prefix + '.metadata.json'
        else:
            metadata_filename = os.path.basename(args.query_list.name) + '.metadata.json'

        with open(metadata_filename, 'w') as f:
            f.write(json.dumps(metadata, indent=4))

    if args.output_prefix is not None:
        with open(args.output_prefix + '.lisa.tsv', 'w') as f:
            f.write(results.to_tsv())
    else:
        print(results.to_tsv())


def save_and_get_top_TFs(args, query_name, results, metadata):
    """Write one query's LISA output files and return its significant factors.

    Output paths start with ``args.output_prefix + query_name``:
      * ``.lisa.tsv`` is always written from ``results.to_tsv()``;
      * ``.metadata.json`` is written only when ``args.save_metadata`` is set.

    Returns the ``factor`` values of rows whose ``combined_p_value_adjusted``
    is <= 0.05. They keep their order of first appearance, with duplicates
    removed. If the filtering/lookup raises KeyError, ``['None']`` is returned
    instead.
    """
    output_stem = args.output_prefix + query_name

    with open(output_stem + '.lisa.tsv', 'w', encoding = 'utf-8') as f:
        f.write(results.to_tsv())

    if args.save_metadata:
        with open(output_stem + '.metadata.json', 'w', encoding = 'utf-8') as f:
            f.write(json.dumps(metadata, indent=4))

    try:
        top_TFs = results.filter_rows(lambda x : x <= 0.05, 'combined_p_value_adjusted').todict()['factor']
    except KeyError:
        top_TFs = ['None']

    # dict keys keep insertion order, so this de-duplicates while preserving
    # the first-seen order of the factors.
    return list(dict.fromkeys(top_TFs))


def print_results_multi(results_summary):
    """Print a tab-separated summary table to stdout.

    *results_summary* is a sequence of ``(sample_name, factors)`` pairs. Each
    pair becomes one row: the sample name, a tab, then the factors joined
    by ", ".
    """
    header = 'Sample\tTop Regulatory Factors (p < 0.05)'
    print(header)
    for entry in results_summary:
        sample_name, factors = entry[0], entry[1]
        print(sample_name, ', '.join(factors), sep = '\t')

        
def lisa_multi(args):
    """Handle the ``multi`` subcommand: run LISA on each of several gene lists.

    A single LISA instance is built once and reused for every list in
    ``args.query_lists``. Each list is identified by its file's basename.
    Per-list output files are written by save_and_get_top_TFs, and a summary
    table is printed to stdout. A list whose processing raises AssertionError
    is logged and skipped.

    Exits with an error before any modelling if two different files share a
    basename: one list would otherwise be silently dropped, and their outputs
    would share the same paths.
    """
    log = Log(target = sys.stderr, verbose = args.verbose)

    query_dict, source_paths = {}, {}
    for query in args.query_lists:
        query_name = os.path.basename(query.name)
        first_path = source_paths.setdefault(query_name, query.name)
        if first_path != query.name:
            sys.exit('ERROR: gene lists "{}" and "{}" share the file name "{}", so their outputs would overwrite each other. Please rename one of them.'.format(first_path, query.name, query_name))
        query_dict[query_name] = query.readlines()

    lisa = LISA(args.species, **extract_kwargs(args, INSTANTIATION_KWARGS), log = log)
    # Identical for every list, so compute it once.
    prediction_kwargs = extract_kwargs(args, PREDICTION_KWARGS)

    results_summary = []
    for query_name, query_list in query_dict.items():

        with log.section('Modeling {}:'.format(str(query_name))):
            try:
                results, metadata = lisa.predict(query_list, **prediction_kwargs)

                top_TFs_unique = save_and_get_top_TFs(args, query_name, results, metadata)

                results_summary.append((query_name, top_TFs_unique))

            except AssertionError as err:
                log.append('ERROR: ' + str(err))

    print_results_multi(results_summary)


def lisa_one_v_rest(args):
    """Handle the ``one-vs-rest`` subcommand.

    Each gene list in ``args.query_lists`` is modelled against a background
    made of every gene from all the *other* lists, passed to predict() as
    ``background_list``. Lists are identified by their file's basename.
    Per-list output files are written by save_and_get_top_TFs, and a summary
    table is printed to stdout. A list whose processing raises AssertionError
    is logged and skipped.

    Exits with an error before any modelling if two different files share a
    basename: one list would otherwise be silently dropped, and their outputs
    would share the same paths.
    """
    log = Log(target = sys.stderr, verbose = args.verbose)

    queries, source_paths = {}, {}
    for query in args.query_lists:
        query_name = os.path.basename(query.name)
        first_path = source_paths.setdefault(query_name, query.name)
        if first_path != query.name:
            sys.exit('ERROR: gene lists "{}" and "{}" share the file name "{}", so their outputs would overwrite each other. Please rename one of them.'.format(first_path, query.name, query_name))
        queries[query_name] = query.readlines()

    lisa = LISA(args.species, **extract_kwargs(args, INSTANTIATION_KWARGS), log = log)
    # Identical for every list, so compute it once.
    prediction_kwargs = extract_kwargs(args, ['background_strategy', 'seed'])

    results_summary = []
    for query_name, query_genes in queries.items():

        # Background: every gene of every other list, in input order.
        background_genes = [
            gene
            for other_name, other_genes in queries.items() if other_name != query_name
            for gene in other_genes
        ]

        with log.section('Modeling {}:'.format(str(query_name))):
            try:
                results, metadata = lisa.predict(query_genes, background_list = background_genes, **prediction_kwargs)

                top_TFs_unique = save_and_get_top_TFs(args, query_name, results, metadata)

                results_summary.append((query_name, top_TFs_unique))

            except AssertionError as err:
                log.append('ERROR: ' + str(err))

    print_results_multi(results_summary)


def build_common_args(parser):
    """Register the arguments shared by the oneshot, multi and one-vs-rest subcommands.

    Adds the positional ``species`` plus ``-c/--cores``, ``--seed``,
    ``--use_motifs`` (stored as ``isd_method``: 'motifs' when given, else
    'chipseq'), and ``--save_metadata``.
    """
    parser.add_argument('species', choices = ['hg38','mm10'], help = 'Find TFs associated with human (hg38) or mouse (mm10) genes')
    # Optional on purpose: argparse never applies the default of a required
    # option, so required=True would make default=1 unreachable.
    parser.add_argument('-c','--cores', type = int, default = 1, help = 'Number of cores to use.')
    parser.add_argument('--seed', type = int, default = None, help = 'Random seed for gene selection. Allows for reproducing exact results.')
    parser.add_argument('--use_motifs', action = 'store_const', const = 'motifs', default='chipseq',
        dest = 'isd_method', help = 'Use motif hits instead of ChIP-seq peaks to represent TF binding (only recommended if TF-of-interest is not represented in ChIP-seq database).')
    parser.add_argument('--save_metadata', action = 'store_true', default = False, help = 'Save json-formatted metadata from processing each gene list.')


def build_multiple_lists_args(parser):
    """Register the arguments shared by subcommands that take several gene lists.

    Adds the positional ``query_lists`` (one or more files opened as UTF-8
    text), the required ``-o/--output_prefix`` (checked by is_valid_prefix)
    and ``-v/--verbose`` (int, default 2).
    """
    gene_list_file = argparse.FileType('r', encoding = 'utf-8')
    parser.add_argument(
        'query_lists',
        nargs = "+",
        type = gene_list_file,
        help = 'user-supplied gene lists. One gene per line in either symbol or refseqID format',
    )
    parser.add_argument(
        '-o', '--output_prefix',
        required = True,
        type = is_valid_prefix,
        help = 'Output file prefix.',
    )
    parser.add_argument('-v', '--verbose', default = 2, type = int)

if __name__ == "__main__":

    # Build the command-line interface. Each subcommand's parser stores its
    # handler function in ``func`` via set_defaults.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description =
"""
Lisa: inferring transcriptional regulators through integrative modeling of public chromatin accessibility and ChIP-seq data\n
https://genomebiology.biomedcentral.com/articles/10.1186/s13059-020-1934-6\n
X. Shirley Liu Lab, 2020\n
""")

    parser.add_argument('--version', action = 'version', version = __version__)

    subparsers = parser.add_subparsers(help = 'commands')

    #__ LISA oneshot command __#################

    oneshot_parser = subparsers.add_parser('oneshot', help = 'Use LISA to infer transcriptional regulators from one gene list. If you have multiple lists, this option will be slower than using "multi" due to data-loading time.\n')
    build_common_args(oneshot_parser)

    oneshot_parser.add_argument('query_list', type = argparse.FileType('r', encoding = 'utf-8'), help = 'user-supplied gene list. One gene per line in either symbol or refseqID format')
    oneshot_parser.add_argument('-o','--output_prefix', required = False, type = is_valid_prefix, help = 'Output file prefix. If left empty, will write results to stdout.')
    oneshot_parser.add_argument('--background_strategy', choices = _config.get('lisa_params', 'background_strategies').split(','),
        default = 'regulatory',
        help = """Background genes selection strategy. LISA samples background genes to compare to user\'s genes-of-interest from a diverse
        regulatory background (regulatory - recommended), randomly from all genes (random), or uses a user-provided list (provided).
        """)
    # Either an explicit background list or a sampled-background size, not both.
    background_genes_group = oneshot_parser.add_mutually_exclusive_group()
    background_genes_group.add_argument('--background_list', type = argparse.FileType('r', encoding = 'utf-8'), required = False,
        help = 'user-supplied list of background genes. Used when --background_strategy flag is set to "provided"')
    background_genes_group.add_argument('-b','--num_background_genes', type = int, default = _config.get('lisa_params', 'background_genes'),
        help = 'Number of sampled background genes to compare to user-supplied genes')
    oneshot_parser.add_argument('-v','--verbose',type = int, default = 4)
    oneshot_parser.set_defaults(func = lisa_oneshot, oneshot = True)

    #__ LISA multi command __#################

    multi_parser = subparsers.add_parser('multi', help = 'Process multiple gene lists. This reduces data-loading time if using the same parameters for all lists.\n')
    build_common_args(multi_parser)
    build_multiple_lists_args(multi_parser)
    multi_parser.add_argument('-b','--num_background_genes', type = int, default = _config.get('lisa_params', 'background_genes'),
        help = 'Number of sampled background genes to compare to user-supplied genes.')
    multi_parser.add_argument('--random_background', action = 'store_const', const = 'random', default = 'regulatory', dest = 'background_strategy', help = 'Use random background selection rather than "regulatory" selection.')
    # multi has no --background_list option; predict() receives None for it.
    multi_parser.set_defaults(func = lisa_multi, oneshot = False, background_list = None)

    #__ LISA one-vs-rest command __#################

    one_v_rest_parser = subparsers.add_parser('one-vs-rest', help = 'Compare gene lists in a one-vs-rest fashion. Useful downstream of cluster analysis.\n')
    build_common_args(one_v_rest_parser)
    build_multiple_lists_args(one_v_rest_parser)
    # Backgrounds are assembled from the other lists, so the strategy is fixed to "provided".
    one_v_rest_parser.set_defaults(func = lisa_one_v_rest, oneshot = False, background_strategy = 'provided')

    #__ LISA download command __#################

    download_data_parser = subparsers.add_parser('download', help = 'Download data from CistromeDB. Use if data received is incomplete or malformed.')
    download_data_parser.add_argument('species', choices = ['hg38','mm10'], help = 'Download data associated with human (hg38) or mouse (mm10) genes')
    download_data_parser.set_defaults(func = lisa_download)

    args = parser.parse_args()

    # Running with no subcommand leaves ``func`` unset: show the help instead
    # of failing. print_help() writes the text itself (it returns None), so it
    # is given the target stream directly rather than wrapped in print().
    if hasattr(args, 'func'):
        args.func(args)
    else:
        parser.print_help(sys.stderr)
    