# AUTOGENERATED! DO NOT EDIT! File to edit: 03_connectomes.ipynb (unless otherwise specified).

__all__ = ['connectome']

# Internal Cell
import os
import bids
bids.config.set_option('extension_initial_dot', True)

from nipype.pipeline import Node, Workflow

import pipetography.nodes as nodes
import pipetography.core as ppt

# Cell

class connectome:
    """
    Create a nipype pipeline that produces connectomes from existing streamlines and a list of atlases.

    One sub-graph is created per subject/session combination found in the BIDS
    directory (as returned by ``ppt.get_subs``). The linear-registration node
    iterates over every entry of ``atlas_list``.

    Inputs:
         - BIDS_dir (str): base BIDS directory path.
         - atlas_list (list): atlases to parcellate with (e.g. aal, brainnectome,
           desikan-killiany). There is no default; it must be supplied. Entries
           are used directly as the ``moving_image`` iterable of the
           linear-registration node, so they are presumably atlas image paths
           (TODO confirm against nodes.PostProcNodes).
         - skip_tuples (list of tuples): combinations to skip, presumably
           (subject, session) pairs. Forwarded unchanged to
           ``nodes.PostProcNodes``. Default: ``[()]``.
         - debug (bool): Default = False. If True, node working directories are
           kept so intermediate outputs and logs can be inspected. If False,
           nipype's ``remove_node_directories`` option is enabled.

    Intended call order: ``create_nodes()`` -> ``connect_nodes()`` ->
    optionally ``draw_pipeline()`` -> ``run_pipeline()``.
    """

    def __init__(self, BIDS_dir, atlas_list, skip_tuples=None, debug=False):
        """
        Store configuration, discover subjects/sessions, and build the file templates used to locate each subject's inputs.
        """
        if skip_tuples is None:
            # Fresh list per instance. A literal ``[()]`` default would be a
            # single list object shared by every instance (and by PostProcNodes).
            skip_tuples = [()]
        self.bids_dir = BIDS_dir
        self.atlas_list = atlas_list
        self.sub_list, self.ses_list, self.layout = ppt.get_subs(BIDS_dir)
        self.skip_combos = skip_tuples
        self.debug_mode = debug
        # Populated later by create_nodes() / connect_nodes().
        self.PostProcNodes = None
        self.workflow = None

        # Templates contain {subject_id}/{session_id} placeholders that are
        # filled in per subject/session by the select_files node.
        preproc_dir = os.path.join(
            self.bids_dir, 'derivatives', 'pipetography',
            'sub-{subject_id}', 'ses-{session_id}', 'preprocessed',
        )
        self.subject_template = {
            'tck': os.path.join(
                self.bids_dir, 'derivatives', 'streamlines',
                'sub-{subject_id}', 'ses-{session_id}',
                'sub-{subject_id}_ses-{session_id}_gmwmi2wm.tck',
            ),
            'brain': os.path.join(preproc_dir, 'dwi_space-acpc_res-1mm_seg-brain.nii.gz'),
            'dwi_mif': os.path.join(preproc_dir, 'dwi_space-acpc_res-1mm.mif'),
            'T1A': os.path.join(preproc_dir, 'T1w_space-acpc.nii.gz'),
            'mask': os.path.join(preproc_dir, 'dwi_space-acpc_res-1mm_seg-brain_mask.nii.gz'),
            'mrtrix5tt': os.path.join(preproc_dir, 'T1w_space-acpc_seg-5tt.mif'),
        }

    def create_nodes(self):
        """
        Create the postprocessing nodes and make the linear registration iterate over the atlases.

        Any output-path substitutions are presumably configured inside
        ``nodes.PostProcNodes`` (not visible here — confirm). Resets
        ``self.workflow`` so ``connect_nodes()`` must be called again.
        """
        self.PostProcNodes = nodes.PostProcNodes(
            BIDS_dir=self.bids_dir,
            subj_template=self.subject_template,
            sub_list=self.sub_list,
            ses_list=self.ses_list,
            skip_tuples=self.skip_combos,
        )
        # One registration/parcellation branch per atlas.
        self.PostProcNodes.linear_reg.iterables = [('moving_image', self.atlas_list)]
        self.workflow = None

    @staticmethod
    def _execution_config(debug):
        """
        Return a fresh nipype ``execution`` config dict for the given debug mode.

        Node directories are only removed when NOT debugging, so debug runs
        keep intermediate outputs and logs on disk.
        """
        config = {
            "use_relative_paths": "True",
            "hash_method": "content",
            "stop_on_first_crash": "True",
        }
        if not debug:
            config["remove_node_directories"] = "True"
        return config

    def connect_nodes(self, wf_name="connectomes"):
        """
        Connect the postprocessing nodes into a workflow rooted at ``<BIDS_dir>/derivatives``.

        Raises:
            RuntimeError: if ``create_nodes()`` has not been called yet.
        """
        pn = getattr(self, "PostProcNodes", None)
        if pn is None:
            raise RuntimeError("create_nodes() must be called before connect_nodes()")

        self.workflow = Workflow(name=wf_name, base_dir=os.path.join(self.bids_dir, 'derivatives'))
        self.workflow.connect(
            [
                # Subject/session identifiers drive file selection.
                (pn.subject_source, pn.select_files, [('subject_id', 'subject_id'),
                                                      ('session_id', 'session_id')]),
                # Atlas -> subject brain: linear then nonlinear registration.
                (pn.select_files, pn.linear_reg, [('brain', 'fixed_image')]),
                (pn.linear_reg, pn.nonlinear_reg, [('warped_image', 'moving_image')]),
                (pn.select_files, pn.nonlinear_reg, [('brain', 'fixed_image')]),
                (pn.nonlinear_reg, pn.round_atlas, [('warped_image', 'in_file')]),
                (pn.round_atlas, pn.connectome, [('out_file', 'in_parc')]),
                (pn.round_atlas, pn.distance, [('out_file', 'in_parc')]),
                # Response functions and FODs from the preprocessed DWI.
                (pn.select_files, pn.response, [('dwi_mif', 'in_file')]),
                (pn.select_files, pn.fod, [('dwi_mif', 'in_file'),
                                           ('mask', 'mask_file')]),
                (pn.response, pn.fod, [('wm_file', 'wm_txt'),
                                       ('gm_file', 'gm_txt'),
                                       ('csf_file', 'csf_txt')]),
                # SIFT2 streamline weights.
                (pn.select_files, pn.sift2, [('mrtrix5tt', 'act'),
                                             ('tck', 'in_file')]),
                (pn.fod, pn.sift2, [('wm_odf', 'in_fod')]),
                # Weighted connectome and distance matrices.
                (pn.sift2, pn.connectome, [('out_file', 'in_weights')]),
                (pn.select_files, pn.connectome, [('tck', 'in_file')]),
                (pn.select_files, pn.distance, [('tck', 'in_file')]),
                (pn.connectome, pn.datasink, [('out_file', 'connectomes.@connectome')]),
                (pn.distance, pn.datasink, [('out_file', 'connectomes.@distance')]),
            ])
        self.workflow.config["execution"] = self._execution_config(self.debug_mode)

    def draw_pipeline(self, graph_type='orig'):
        """
        Write a graph of the workflow to ``<BIDS_dir>/derivatives/pipetography/graph/postprocessing.dot``.

        Raises:
            RuntimeError: if ``connect_nodes()`` has not been called yet.
        """
        if getattr(self, "workflow", None) is None:
            raise RuntimeError("connect_nodes() must be called before draw_pipeline()")
        graph_dir = os.path.join(self.bids_dir, 'derivatives', 'pipetography', 'graph')
        # Make sure the target directory exists before handing the path to nipype.
        os.makedirs(graph_dir, exist_ok=True)
        self.workflow.write_graph(
            graph2use=graph_type,
            dotfilename=os.path.join(graph_dir, 'postprocessing.dot'),
        )

    def run_pipeline(self, parallel=None):
        """
        Run the nipype workflow.

        Args:
            parallel: ``None`` runs serially. A positive int runs with the
                MultiProc plugin using that many processes.

        Raises:
            RuntimeError: if ``connect_nodes()`` has not been called yet.
            TypeError: if ``parallel`` is neither None nor an int (bools are rejected).
            ValueError: if ``parallel`` is an int smaller than 1.
        """
        if getattr(self, "workflow", None) is None:
            raise RuntimeError("connect_nodes() must be called before run_pipeline()")

        if parallel is None:
            print("Parallel processing disabled, running workflow serially")
            self.workflow.run()
            return

        # bool is a subclass of int; True/False are not meaningful process counts.
        if isinstance(parallel, bool) or not isinstance(parallel, int):
            raise TypeError(
                "parallel must be None or a positive int, got {!r}".format(parallel))
        if parallel < 1:
            raise ValueError(
                "parallel must be a positive number of processes, got {}".format(parallel))

        print("Running workflow with {} parallel processes".format(parallel))
        self.workflow.run('MultiProc', plugin_args={'n_procs': parallel})