#!/usr/bin/env python3

"""
Author: Shadi Zabad
Date: May 2022

This is a command-line script that generates polygenic scores
for test samples, given effect size estimates from VIPRS.
"""

import os.path as osp
import argparse
import viprs as vp
from magenpy.utils.system_utils import makedir, get_filenames
from magenpy.GWADataLoader import GWADataLoader
from viprs.model.PRSModel import PRSModel

# Startup banner, printed as soon as the script runs (before argument parsing).
# NOTE: this is an f-string (required for the version field), not a raw string,
# so the backslash in the ASCII art is written as `\\`. A bare `\_` is an
# invalid escape sequence and makes the compiler emit a warning; the printed
# output is unchanged.
print(f"""
        **********************************************
                    _____
            ___   _____(_)________ ________________
            __ | / /__  / ___  __ \\__  ___/__  ___/
            __ |/ / _  /  __  /_/ /_  /    _(__  )
            _____/  /_/   _  .___/ /_/     /____/
                          /_/
        Variational Inference of Polygenic Risk Scores
        Version: {vp.__version__} | Release date: June 2022
        Author: Shadi Zabad, McGill University
        **********************************************
        < Compute polygenic scores for test samples >
""")

# Command-line interface definition. `parser` and `args` are module-level
# names that the rest of the script reads.
parser = argparse.ArgumentParser(description="""
Commandline arguments for generating polygenic scores
""")

# --- Required inputs / outputs ---
parser.add_argument('-f', '--fit-files', dest='fit_files', type=str, required=True,
                    help='The path to the file(s) with the output parameter estimates from VIPRS. '
                         'You may use a wildcard here (e.g. "prs/chr_*.fit")')
parser.add_argument('--bed-files', dest='bed_files', type=str, required=True,
                    help='The BED files containing the genotype data. '
                         'You may use a wildcard here (e.g. "data/chr_*.bed")')
parser.add_argument('--output-file', dest='output_file', type=str, required=True,
                    help='The output file where to store the polygenic scores (with no extension).')

# --- Optional settings ---
parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                    help='The temporary directory where to store intermediate files.')

parser.add_argument('--keep', dest='keep', type=str,
                    help='A plink-style keep file to select a subset of individuals for the test set.')
parser.add_argument('--extract', dest='extract', type=str,
                    help='A plink-style extract file to select a subset of SNPs for scoring.')
# A tuple (rather than a set) keeps the order of the choices stable in the
# `--help` text and in "invalid choice" error messages.
parser.add_argument('--backend', dest='backend', type=str, default='xarray',
                    choices=('xarray', 'plink'),
                    help='The backend software used for computations with the genotype matrix.')
parser.add_argument('--n-threads', dest='n_threads', type=int, default=1,
                    help='The number of threads to use for computations.')
parser.add_argument('--compress', dest='compress', action='store_true', default=False,
                    help='Compress the output file')

# Parse sys.argv; argparse exits with a usage message if a required option is missing.
args = parser.parse_args()

# ----------------------------------------------------------
# Echo the options that differ from their defaults, so the log records
# what this run was configured with. Unset (None) options are skipped.
section_title = '  Parsed arguments  '
print(f'{section_title:-^62}\n')

for opt_name, opt_value in vars(args).items():
    if opt_value is None:
        continue
    if opt_value == parser.get_default(opt_name):
        continue
    print("--", opt_name, ":", opt_value)

# ----------------------------------------------------------
print('\n{:-^62}\n'.format('  Reading input data  '))

# Resolve the parameter files first, so that a bad `--fit-files` pattern is
# reported before the genotype loader is constructed.
fit_files = get_filenames(args.fit_files, extension='.fit')

if len(fit_files) < 1:
    # Build a single message string: OSError subclasses interpret two
    # positional arguments as (errno, strerror), which garbles the output.
    raise FileNotFoundError(
        f"Did not find any parameter fit files at: {args.fit_files}"
    )

# Set up the genotype data for the test samples, applying the optional
# keep/extract files.
# NOTE(review): min_mac/min_maf=None presumably disables the allele
# count/frequency filters so that all variants are kept — confirm against
# the GWADataLoader documentation.
test_data = GWADataLoader(args.bed_files,
                          keep_file=args.keep,
                          extract_file=args.extract,
                          min_mac=None,
                          min_maf=None,
                          backend=args.backend,
                          temp_dir=args.temp_dir,
                          n_threads=args.n_threads)

# Wrap the test data in a PRS model and load the inferred parameters into it.
prs_m = PRSModel(test_data)
prs_m.read_inferred_parameters(fit_files)

# ----------------------------------------------------------
print('\n{:-^62}\n'.format('-'))
# Predict on the test set:
print("> Generating polygenic scores...")
prs = test_data.score(prs_m.get_posterior_mean_beta())

# Save the PRS as a table. The individual table is copied before a 'PRS'
# column is added, so the loader's own table is left untouched.
ind_table = test_data.to_individual_table().copy()
ind_table['PRS'] = prs

# Clean up all the intermediate files/directories
test_data.cleanup()

# If the user wants the files to be compressed, append `.gz` to the name:
c_ext = '.gz' if args.compress else ''
output_path = args.output_file + '.prs' + c_ext

# Report the file that is actually written (not just its directory, which is
# an empty string when the output prefix has no directory component).
print("\n>>> Writing the polygenic scores to:\n", output_path)

# Output the scores:
# `osp.dirname` returns '' for a bare file name; only create the directory
# when there is one (presumably `makedir` cannot handle an empty path —
# skipping it is safe either way).
output_dir = osp.dirname(args.output_file)
if output_dir:
    makedir(output_dir)

ind_table.to_csv(output_path, index=False, sep="\t")
