#!/usr/bin/env python
import os
from argparse import ArgumentParser

from sklearn.model_selection import KFold
from smnsr.patients import AugmentedTADPOLEData, TADPOLEData

# Command-line options as (flag, value type, default) triples, registered in
# the order listed so the generated --help output keeps the same ordering.
_OPTIONS = (
    ("--modality_path", str, "../../modalities/"),
    ("--folds", int, 10),
    ("--fold_prefix", str, "fold_"),
    ("--merge_prefix", str, "merged_"),
    ("--extension", str, ".p"),
    ("--fold_path", str, "../../output/precomputed_folds/"),
    ("--output_path", str, "../../output/"),
    ("--modality_k", int, 8),
)

cli = ArgumentParser()
for _flag, _value_type, _fallback in _OPTIONS:
    cli.add_argument(_flag, type=_value_type, default=_fallback)

# Parse the process arguments; each option is exposed as an attribute of args.
args = cli.parse_args()


# Load the TADPOLE data from the configured modality directory, with the
# loader's own output silenced (verbosity=0).
data = TADPOLEData(
    verbosity=0,
    modality_k=args.modality_k,
    modality_path=args.modality_path,
)

# Identifiers returned by the dataset (presumably patient IDs -- confirm
# against TADPOLEData.get_ptids); they are what gets split into folds below.
ptids = data.get_ptids()

def _merge_and_save(fold_tag, merged_tag, ptid_subset, verbosity):
    """Build an AugmentedTADPOLEData from one fold file and save the result.

    The input path is ``<fold_path>/<fold_prefix><fold_tag><extension>`` and
    the output path is ``<output_path>/<merge_prefix><merged_tag><extension>``.
    Paths are assembled with ``os.path.join`` so that directory arguments work
    with or without a trailing separator; for the default arguments the
    resulting paths are identical to plain string concatenation.

    Args:
        fold_tag: Text inserted between the fold prefix and the extension of
            the input file name (may be empty).
        merged_tag: Text inserted between the merge prefix and the extension
            of the output file name.
        ptid_subset: Identifiers passed to AugmentedTADPOLEData as its third
            positional argument.
        verbosity: Verbosity level forwarded to AugmentedTADPOLEData.

    Returns:
        The AugmentedTADPOLEData instance after ``save`` has been called on it
        with ``overwrite=True``.
    """
    fold_file = os.path.join(
        args.fold_path, args.fold_prefix + fold_tag + args.extension
    )
    augmented = AugmentedTADPOLEData(
        data, fold_file, ptid_subset, verbosity=verbosity
    )
    merged_file = os.path.join(
        args.output_path, args.merge_prefix + merged_tag + args.extension
    )
    augmented.save(merged_file, overwrite=True)
    return augmented


if args.folds == 0:
    # No splitting requested: use every identifier, read the un-numbered fold
    # file (prefix directly followed by the extension) and write "<merge_prefix>all".
    augmented_data = _merge_and_save("", "all", ptids, verbosity=2)
else:
    # Shuffled K-fold split with a fixed seed, so the same partition is
    # produced on every run with the same inputs.
    splitter = KFold(n_splits=args.folds, shuffle=True, random_state=0)
    for fold, (train_index, _) in enumerate(splitter.split(ptids)):
        print(f"Merging fold {fold}")
        # Only the training side of each split is merged; the test indices
        # are not used by this script.
        train_ptids = [ptids[i] for i in train_index]
        augmented_data = _merge_and_save(
            str(fold), str(fold), train_ptids, verbosity=0
        )
