#!/usr/bin/env python
import argparse
import logging
import os

import nibabel as nib
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

import react
from react.utils import check_can_write_file, volume4d_to_matrix

# Suffixes appended to the user-supplied output prefix (`out_react`).
OUT_REACT_STAGE1 = '_react_stage1.txt'
OUT_REACT_STAGE2 = '_react_stage2.nii.gz'
# Per-map stage 2 outputs: this suffix is followed by the map index and
# the '.nii.gz' extension.
OUT_REACT_STAGE2_IC = '_react_stage2_IC'

# Text handed to argparse for the program name, description and epilog.
PROG = 'react'
DESCRIPTION = (
    f'\nv{react.__version__} Receptor-Enriched Analysis of Functional '
    'Connectivity by\n'
    'Targets. All files must be in the same space.\n'
)
EPILOG = (
    'REFERENCE: https://doi.org/10.1016/j.neuroimage.2019.04.007 - '
    'Dipasquale, O., Selvaggi, P., Veronese, M., Gabay, A. S., '
    'Turkheimer, F., & Mehta, M. A. (2019). Receptor-Enriched Analysis '
    'of functional connectivity by targets (REACT): A novel, multimodal '
    'analytical approach informed by PET to study the pharmacodynamic '
    'response of the brain under MDMA. Neuroimage, 195, 252-260.'
)


def get_parsed_args():
    """Build the command-line parser and parse the process arguments.

    Positional arguments (in order): `in_fmri`, `mask_stage1`,
    `mask_stage2`, `pet_atlas` and `out_react`. Optional flags:
    `--data_norm`, `--force` and `-v/--verbose`.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(prog=PROG, epilog=EPILOG,
                                     description=DESCRIPTION)

    parser.add_argument(
        'in_fmri',
        type=str,
        help='File name of the input rs-fMRI volume. '
             'E.g., `/home/study/data/subject001_fmri.nii.gz`'
    )

    parser.add_argument(
        'mask_stage1',
        type=str,
        help='File name of mask for thresholding '
             'in the first GLM (estimated in react_masks.py). '
             'E.g., `/home/study/data/mask_stage1.nii.gz`'
    )

    parser.add_argument(
        'mask_stage2',
        type=str,
        help='File name of mask for thresholding '
             'in the second GLM (estimated in react_masks.py). '
             'E.g., `/home/study/data/mask_stage2.nii.gz`'
    )

    parser.add_argument(
        'pet_atlas',
        type=str,
        help='3D or 4D file containing the PET atlases '
             'to be used in the REACT analysis. '
             'It is recommended to rescale each PET atlas '
             'between 0 and 1 before running REACT. '
             'E.g., `/home/study/data/PETatlas.nii.gz`'
    )

    # The file names below must mirror what main() actually writes: one
    # stage 1 text file, one stage 2 volume, and one `IC<N>` volume per PET
    # map, where the index is appended without zero padding.
    parser.add_argument(
        'out_react',
        type=str,
        help='Prefix of the output files. '
             'The output files will be named as: '
             '`<out_react>_react_stage1.txt`, '
             '`<out_react>_react_stage2.nii.gz` '
             'and `<out_react>_react_stage2_IC<N>.nii.gz`, where `<N>` is '
             'the zero-based index of the PET map. '
             'E.g., `/home/study/REACT/subject001` will generate files '
             '`/home/study/REACT/subject001_react_stage1.txt`, '
             '`/home/study/REACT/subject001_react_stage2.nii.gz` and '
             '`/home/study/REACT/subject001_react_stage2_IC<N>.nii.gz`'
    )

    parser.add_argument(
        '--data_norm',
        action='store_true',
        help='If set, normalizes the input data of stage 2 to unit standard '
             'deviation.'
    )

    parser.add_argument(
        '--force',
        action='store_true',
        help="Overwrite existing files."
    )

    parser.add_argument(
        '-v', '--verbose',
        dest='verbose',
        action='store_true',
        help='Set verbose output.'
    )

    return parser.parse_args()


def _load_mask(fpath, expected_shape, label):
    """Load a mask volume and return it as a flat boolean vector.

    Args:
        fpath: path of the mask image, loaded with nibabel.
        expected_shape: shape the mask must have (the spatial shape of the
            rs-fMRI volume).
        label: text used at the start of the error message, e.g. 'Stage 1'.

    Returns:
        1D boolean array that is True where the mask value is > 0.

    Raises:
        ValueError: if the mask shape differs from `expected_shape`.
    """
    mask = nib.load(fpath).get_fdata()
    if mask.shape != expected_shape:
        raise ValueError(f'{label} mask has shape {mask.shape}, while '
                         f'rs-fMRI volume has shape {expected_shape}. '
                         f'They must be equal.')
    return (mask > 0).ravel()


def main():
    """Run the two-stage REACT regression from the command line.

    Stage 1 regresses the rs-fMRI data on the PET maps within the stage 1
    mask and writes the coefficients to `<out_react>_react_stage1.txt`.
    Stage 2 regresses the rs-fMRI data within the stage 2 mask on the
    standardised stage 1 coefficients and writes the resulting maps to
    `<out_react>_react_stage2.nii.gz`, plus one
    `<out_react>_react_stage2_IC<i>.nii.gz` file per PET map.

    Raises:
        ValueError: if the rs-fMRI volume is not 4D, or if the PET volume or
            either mask does not match its spatial shape.
    """
    args = get_parsed_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if args.force:
        logging.warning('Overwriting existing files')

    # Resolve the prefix once so that every output file is derived from the
    # same base path (normalising prefix + suffix together can yield a
    # different directory when the prefix ends with a path separator).
    out_prefix = os.path.abspath(args.out_react)

    fpath_result_stage1 = out_prefix + OUT_REACT_STAGE1
    check_can_write_file(fpath_result_stage1, args.force, args.verbose)

    fpath_result_stage2 = out_prefix + OUT_REACT_STAGE2
    check_can_write_file(fpath_result_stage2, args.force, args.verbose)

    volume_fmri = nib.load(args.in_fmri)
    if volume_fmri.ndim != 4:
        # Fail early: the last axis is used as time and the first three as
        # space everywhere below.
        raise ValueError(f'rs-fMRI volume has shape {volume_fmri.shape}. '
                         f'It must be 4D.')
    spatial_shape = volume_fmri.shape[:3]
    n_times = volume_fmri.shape[-1]

    volume_pet = nib.load(args.pet_atlas)
    if volume_pet.shape[:3] != spatial_shape:
        raise ValueError(f'PET volume has shape {volume_pet.shape[:3]} while '
                         f'rs-fMRI volume has shape {spatial_shape}. '
                         f'They must be equal.')
    n_pet_maps = 1 if volume_pet.ndim == 3 else volume_pet.shape[3]

    # One output file per PET map; all are checked before any computation.
    fpath_ic = [f'{out_prefix}{OUT_REACT_STAGE2_IC}{i}.nii.gz'
                for i in range(n_pet_maps)]
    for f in fpath_ic:
        check_can_write_file(f, args.force, args.verbose)

    mask_stage1 = _load_mask(args.mask_stage1, spatial_shape, 'Stage 1')
    mask_stage2 = _load_mask(args.mask_stage2, spatial_shape, 'Stage 2')

    logging.info(f'Volume shape: {volume_fmri.shape}')
    logging.info(f'N time points: {n_times}')
    logging.info(f'N PET maps: {n_pet_maps}')
    logging.info(f'N voxels in mask stage 1: {np.count_nonzero(mask_stage1)}')
    logging.info(f'N voxels in mask stage 2: {np.count_nonzero(mask_stage2)}')

    # Stage 1
    logging.info('Initiating stage 1')

    # NOTE(review): volume4d_to_matrix presumably returns a
    # (n_voxels, last-axis) matrix with voxels in the same order as
    # `ravel()` -- the row indexing with the flat masks below relies on it.
    rsfmri = volume4d_to_matrix(volume_fmri.get_fdata())
    pet = volume_pet.get_fdata()
    if pet.ndim == 3:
        # A single 3D atlas is treated as a 4D volume with one map.
        pet = pet[..., np.newaxis]
    pet = volume4d_to_matrix(pet)

    # Columns are centred using all voxels, then rows are restricted to the
    # stage 1 mask.
    scaler_y = StandardScaler(with_mean=True, with_std=False)
    y = scaler_y.fit_transform(rsfmri)[mask_stage1, :]

    scaler_x = StandardScaler(with_mean=True, with_std=False)
    x = scaler_x.fit_transform(pet)[mask_stage1, :]

    # coef_ has one row per target column of `y` and one column per PET map.
    fit1 = LinearRegression(fit_intercept=True).fit(x, y)
    beta1 = fit1.coef_

    np.savetxt(fpath_result_stage1, beta1)
    logging.info(f'Saved result stage 1 in {fpath_result_stage1}')

    # Stage 2
    logging.info('Initiating stage 2')

    if args.data_norm:
        logging.info('Normalising input data of stage 2 since `--data_norm` '
                     'was specified')
    scaler_y = StandardScaler(with_mean=True, with_std=args.data_norm)
    y = scaler_y.fit_transform(rsfmri.T)[:, mask_stage2]

    scaler_x = StandardScaler(with_mean=True, with_std=True)
    x = scaler_x.fit_transform(beta1)

    model2 = LinearRegression(fit_intercept=True)
    fit2 = model2.fit(x, y)

    # Scatter the coefficients of the masked voxels back into a full volume;
    # voxels outside the stage 2 mask stay at zero.
    beta2 = np.zeros((mask_stage2.size, n_pet_maps))
    beta2[mask_stage2] = fit2.coef_
    # Always keep an explicit map axis, so that indexing the last axis below
    # selects a PET map even when the atlas file itself is 3D.
    beta2 = np.reshape(beta2, spatial_shape + (n_pet_maps,))

    # The combined output keeps the dimensionality of the PET atlas file.
    nib.save(nib.Nifti1Image(np.reshape(beta2, volume_pet.shape),
                             affine=volume_pet.affine),
             fpath_result_stage2)
    logging.info(f'Saved result stage 2 in {fpath_result_stage2}')

    for i, f in enumerate(fpath_ic):
        nib.save(nib.Nifti1Image(np.squeeze(beta2[..., i]),
                                 affine=volume_pet.affine), f)
        logging.info(f'Saved result IC{i} of stage 2 in {f}')


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
