#!/usr/bin/env python3

"""
Author: Shadi Zabad
Date: May 2022

This is a commandline script that enables users to
evaluate the predictive performance of PRS models.

"""

import os.path as osp
import pandas as pd
import argparse

import viprs as vp
from magenpy.utils.system_utils import makedir
from viprs.eval.metrics import *


# Startup banner shown on every invocation.
# A *raw* f-string is used because the ASCII art contains backslashes
# (e.g. `\_`), which a regular string literal parses as invalid escape
# sequences (DeprecationWarning, and SyntaxWarning on Python >= 3.12).
# The `{vp.__version__}` placeholder is still interpolated in raw f-strings,
# so the printed text is unchanged.
print(rf"""
     **************************************************
                    _____
            ___   _____(_)________ ________________
            __ | / /__  / ___  __ \__  ___/__  ___/
            __ |/ / _  /  __  /_/ /_  /    _(__  )
            _____/  /_/   _  .___/ /_/     /____/
                          /_/
        Variational Inference of Polygenic Risk Scores
        Version: {vp.__version__} | Release date: June 2022
        Author: Shadi Zabad, McGill University
     **************************************************
       < Evaluate the predictive performance of PRS >
""")


# Command-line interface definition.
# Required: the PRS file, the phenotype file and the output file prefix.
# Optional: the phenotype likelihood (defaults to 'gaussian') and a covariate file.
parser = argparse.ArgumentParser(description="""
Commandline arguments for evaluating polygenic score estimates
""")

parser.add_argument('--prs-file', dest='prs_file', type=str, required=True,
                    help='The path to the PRS file (expected format: FID IID PRS, tab-separated)')
parser.add_argument('--phenotype-file', dest='pheno_file', type=str, required=True,
                    help='The path to the phenotype file. '
                         'The expected format is: FID IID phenotype (no header), tab-separated.')
# NOTE: `choices` is an ordered tuple rather than a set so that the order in
# which the options appear in `--help` and in error messages is deterministic
# (set iteration order of strings varies between interpreter runs).
parser.add_argument('--phenotype-likelihood', dest='pheno_lik', type=str, default='gaussian',
                    choices=('gaussian', 'binomial'),
                    help='The phenotype likelihood ("gaussian" for continuous, "binomial" for case-control).')
parser.add_argument('--output-file', dest='output_file', type=str, required=True,
                    help='The output file where to store the evaluation metrics (with no extension).')

parser.add_argument('--covariate-file', dest='covariate_file', type=str,
                    help='A file with covariates for the samples included in the analysis. This tab-separated '
                         'file should not have a header and the first two columns should be '
                         'the FID and IID of the samples.')

# Parse `sys.argv`; argparse exits with a usage message if a required option is missing.
args = parser.parse_args()

# ----------------------------------------------------------
# Echo the options the user explicitly set: anything that is unset (None)
# or still equal to the parser's default value is skipped.
print(f"{'  Parsed arguments  ':-^62}\n")

for arg_name, arg_value in vars(args).items():
    if arg_value is None or arg_value == parser.get_default(arg_name):
        continue
    print("--", arg_name, ":", arg_value)

# ----------------------------------------------------------
print('\n{:-^62}\n'.format('-'))

print('> Reading input data...')

# The PRS file is read with its own header row; the phenotype file has no
# header, so its three columns are named explicitly.
prs_df = pd.read_csv(args.prs_file, sep='\t')
pheno_df = pd.read_csv(args.pheno_file, names=['FID', 'IID', 'phenotype'], sep='\t')

# Fail early with a clear message if the PRS file lacks the columns used below,
# instead of surfacing an opaque KeyError from the merge / metric code.
missing_prs_columns = [col for col in ('FID', 'IID', 'PRS') if col not in prs_df.columns]
if missing_prs_columns:
    raise ValueError(
        f"The PRS file '{args.prs_file}' is missing the required column(s): "
        f"{', '.join(missing_prs_columns)}. Expected a tab-separated file with "
        f"the header: FID IID PRS."
    )

# pandas infers the ID dtype separately for each file. If the IDs were parsed as
# numbers in one file and as strings in the other, merging would fail, so in that
# case (and only in that case) both sides are compared as strings.
for id_col in ('FID', 'IID'):
    prs_is_numeric = pd.api.types.is_numeric_dtype(prs_df[id_col])
    pheno_is_numeric = pd.api.types.is_numeric_dtype(pheno_df[id_col])
    if prs_is_numeric != pheno_is_numeric:
        prs_df[id_col] = prs_df[id_col].astype(str)
        pheno_df[id_col] = pheno_df[id_col].astype(str)

# Inner join: keep only the samples that have both a PRS and a phenotype.
merged_df = prs_df.merge(pheno_df, on=['FID', 'IID'])

if merged_df.empty:
    raise ValueError(
        "No samples are shared between the PRS file and the phenotype file "
        "(matching on FID and IID). Please check the sample IDs and file formats."
    )

# ----------------------------------------------------------

print("> Evaluating PRS model performance...")

if args.pheno_lik == 'gaussian':

    # Metrics for continuous phenotypes, computed on every sample that has
    # both a PRS and a phenotype (i.e. before any covariate filtering).
    metric_dict = {
        'R2': r2(merged_df['PRS'].values, merged_df['phenotype'].values),
        'MSE': mse(merged_df['PRS'].values, merged_df['phenotype'].values),
        'Pearson_correlation': pearson_r(merged_df['PRS'].values, merged_df['phenotype'].values)
    }

    if args.covariate_file is not None:
        # The covariate file has no header: the first two columns are the
        # sample IDs and every remaining column is named C0, C1, ...
        covar_df = pd.read_csv(args.covariate_file, header=None, sep='\t')
        covariates = [f'C{i}' for i in range(len(covar_df.columns) - 2)]
        covar_df.columns = ['FID', 'IID'] + covariates

        # Sample IDs may have been parsed as numbers in one table and as
        # strings in the other; compare them as strings in that case so that
        # the merge does not fail on mismatched key dtypes.
        for id_col in ('FID', 'IID'):
            merged_is_numeric = pd.api.types.is_numeric_dtype(merged_df[id_col])
            covar_is_numeric = pd.api.types.is_numeric_dtype(covar_df[id_col])
            if merged_is_numeric != covar_is_numeric:
                merged_df[id_col] = merged_df[id_col].astype(str)
                covar_df[id_col] = covar_df[id_col].astype(str)

        # Inner join: the covariate-adjusted metrics below only use samples
        # that are also present in the covariate file.
        merged_df = merged_df.merge(covar_df, on=['FID', 'IID'])

        if merged_df.empty:
            raise ValueError(
                "No samples are shared between the covariate file and the "
                "PRS/phenotype data (matching on FID and IID)."
            )

        metric_dict['Partial_correlation'] = partial_correlation(merged_df['PRS'].values,
                                                                 merged_df['phenotype'].values,
                                                                 merged_df[covariates])

        # `incremental_r2` returns a mapping whose entries are added to the
        # output metrics as-is.
        metric_dict.update(incremental_r2(merged_df['PRS'].values,
                                          merged_df['phenotype'].values,
                                          merged_df[covariates]))

else:

    if args.covariate_file is not None:
        # Covariate-adjusted metrics are only computed for the gaussian
        # likelihood; tell the user rather than silently ignoring the option.
        print("> Warning: --covariate-file is only used with the 'gaussian' "
              "phenotype likelihood and will be ignored.")

    # Metrics for case-control (binary) phenotypes.
    metric_dict = {
        'AUROC': roc_auc(merged_df['PRS'].values, merged_df['phenotype'].values),
        'AUPRC': pr_auc(merged_df['PRS'].values, merged_df['phenotype'].values),
        'Avg_precision': avg_precision(merged_df['PRS'].values, merged_df['phenotype'].values)
    }

# Resolve the output directory from the absolute path: `osp.dirname()` returns
# an empty string for a bare file name (e.g. `--output-file results`), which
# would print a blank location and hand an empty path to `makedir`.
output_dir = osp.dirname(osp.abspath(args.output_file))

print("\n>>> Writing the evaluation results to:\n", output_dir)

# `makedir` comes from magenpy; presumably it creates the directory when it
# does not exist yet -- confirm against magenpy.utils.system_utils.
makedir(output_dir)

# One-row, tab-separated table (one column per metric), written to
# `<output_file>.eval`.
pd.DataFrame([metric_dict]).to_csv(args.output_file + ".eval", sep="\t", index=False)
