#!/usr/bin/env python
import pickle
import bz2
from sklearn.model_selection import KFold
from smnsr.patients.timeseries_creation import create_features
from smnsr.patients import TADPOLEData
import psutil
import ray
from sklearn.model_selection import KFold
import sys
import argparse

# Command-line options for the fold-creation script. Each row is
# (short flag, long flag, keyword options for add_argument); rows are
# registered in the order listed.
cli = argparse.ArgumentParser()
for short_flag, long_flag, options in (
    # Passed to TADPOLEData as modality_path.
    ("-m", "--modality_path", dict(type=str, default="../../modalities/")),
    # --out, --name and --extension are concatenated verbatim (together with
    # the fold index) to form each output file path.
    ("-o", "--out", dict(type=str, default="../../output/")),
    ("-n", "--name", dict(type=str, default="fold_")),
    ("-e", "--extension", dict(type=str, default=".p")),
    # Number of K-fold splits; a value of 1 or less produces a single
    # feature set built from all patients instead.
    ("-f", "--folds", dict(type=int, default=10)),
    # The first two values are read as the inclusive lower and upper fold
    # indices to process; there is no default.
    ("-r", "--range", dict(nargs="+", type=int)),
    # Worker count; values below 1 let the script pick a count via psutil.
    ("-c", "--cpus", dict(type=int, default=0)),
    # Passed to TADPOLEData as modality_k.
    ("-k", "--modality_k", dict(type=int, default=2)),
):
    cli.add_argument(short_flag, long_flag, **options)

args = cli.parse_args()
# Build the dataset handle from the modality directory and k given on the
# command line.
data = TADPOLEData(modality_path=args.modality_path, modality_k=args.modality_k)

print("Initializing ray")
# A --cpus value below 1 means "pick automatically".
num_cpus = args.cpus
if num_cpus < 1:
    # psutil.cpu_count(logical=False) is documented to return None when the
    # physical core count cannot be determined; fall back to the logical
    # count, and finally to a single worker, so num_cpus is always an int.
    num_cpus = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1

sys.stdout.flush()
# Shut down any Ray session this process may already hold before starting a
# new one with the requested CPU budget.
ray.shutdown()
ray.init(num_cpus=num_cpus)
print("Begin fold creation")
sys.stdout.flush()
# Create time-series features and write them as bz2-compressed pickles.
#
# With --folds > 1 the patient IDs are split by shuffled K-fold and one file
# is written per processed fold; otherwise a single file is built from all
# patients with an empty test set. Only ts_features is persisted; the
# knn_models returned by create_features are not saved.
if args.folds > 1:
    # random_state=0 makes the shuffle reproducible between invocations.
    splitter = KFold(n_splits=args.folds, shuffle=True, random_state=0)
    ptids = data.get_ptids()

    # Resolve the inclusive fold window once. Without --range every fold is
    # processed; previously the progress message indexed args.range even when
    # it was None, which crashed the default invocation.
    if args.range is None:
        first_fold, last_fold = 0, args.folds - 1
    elif len(args.range) >= 2:
        first_fold, last_fold = args.range[0], args.range[1]
    else:
        # nargs="+" lets a single value through the parser; report it as a
        # usage error instead of failing later with an IndexError.
        cli.error("--range expects two values: FIRST_FOLD LAST_FOLD")

    for fold, (train_index, test_index) in enumerate(splitter.split(ptids)):
        if fold < first_fold:
            continue
        if fold > last_fold:
            break

        print("Fold %i/%i" % (fold, last_fold))
        sys.stdout.flush()

        # Map the positional indices from the splitter back to patient IDs.
        train_ptids = [ptids[i] for i in train_index]
        test_ptids = [ptids[i] for i in test_index]
        ts_features, knn_models = create_features(
            data, train_ptids, test_ptids, num_cpus=num_cpus
        )

        # --out is used as a raw string prefix, so it must end with a path
        # separator for the file to land inside that directory.
        fold_path = args.out + args.name + str(fold) + args.extension
        with bz2.BZ2File(fold_path, "w") as file:
            pickle.dump(ts_features, file)
else:
    # No cross-validation: build features for all patients with no test set.
    ts_features, knn_models = create_features(
        data, data.get_ptids(), [], num_cpus=num_cpus
    )
    with bz2.BZ2File(args.out + args.name + args.extension, "w") as file:
        pickle.dump(ts_features, file)
