# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02_utils.ipynb (unless otherwise specified).

__all__ = [
    'globtastic', 'compile_re', 'walk', 'kaggle_braintumor_meta_cols',
    'get_dicom_metadata', 'get_patient_id', 'get_patient_BraTS21ID_path',
    'get_all_dicom_metadata', 'get_image_plane',
]

# Cell
import os
import ast
import wandb
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
# from fastcore.xtras import globtastic

# pydicom related imports
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

# kagglerecipes imports
from .data import TINY_DATA_PATH

# Cell
'''
TEMPORARY UTILS ADDED HERE UNTIL THE NEXT fastcore RELEASE
'''
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.basics import *
from functools import wraps

from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional

import mimetypes,pickle,random,json,subprocess,shlex,bz2,gzip,zipfile,tarfile
import imghdr,struct,distutils.util,tempfile,time,string,collections,shutil
from copy import copy
from contextlib import contextmanager,ExitStack
from pdb import set_trace
from datetime import datetime, timezone
from timeit import default_timer
from fnmatch import fnmatch

def globtastic(
    path:Union[Path,str], # path to start searching
    recursive:bool=True, # search subfolders
    symlinks:bool=True, # follow symlinks?
    file_glob:str=None, # Only include files matching glob
    file_re:str=None, # Only include files matching regex
    folder_re:str=None, # Only enter folders matching regex
    skip_file_glob:str=None, # Skip files matching glob
    skip_file_re:str=None, # Skip files matching regex
    skip_folder_re:str=None # Skip folders matching regex
)->L: # Paths to matched files
    """
    Find files under `path` with glob/regex include and skip filters, returned as an `L`.

    If `path` is itself a file, it is returned on its own. When `recursive` is False,
    no subfolder is entered; this overrides any `skip_folder_re` given.
    """
    start = Path(path)
    if start.is_file():
        return L([start])
    if not recursive:
        # '.' matches any non-empty folder name, so every subfolder is skipped.
        skip_folder_re = '.'
    want_file, want_folder = compile_re(file_re), compile_re(folder_re)
    reject_file, reject_folder = compile_re(skip_file_re), compile_re(skip_folder_re)

    def _keep_file(root, name):
        # A file survives only if it passes every include filter and no skip filter.
        if file_glob and not fnmatch(name, file_glob): return False
        if want_file and not want_file.search(name): return False
        if skip_file_glob and fnmatch(name, skip_file_glob): return False
        if reject_file and reject_file.search(name): return False
        return True

    def _keep_folder(root, name):
        if want_folder and not want_folder.search(name): return False
        if reject_folder and reject_folder.search(name): return False
        return True

    matches = walk(start, symlinks=symlinks, keep_file=_keep_file, keep_folder=_keep_folder)
    return L(matches)

# Cell
def compile_re(pat):
    "Return `re.compile(pat)`, or None when `pat` is None"
    if pat is None:
        return None
    return re.compile(pat)

def walk(
    path:(Path,str), # path to start searching
    symlinks:bool=True, # follow symlinks?
    keep_file:callable=noop, # function that returns True for wanted files
    keep_folder:callable=noop, # function that returns True for folders to enter
    func:callable=os.path.join # function to apply to each matched file
): # Generator of `func` applied to matched files
    """
    Lazily walk `path` with `os.walk`.

    Yields `func(folder, name)` for each file that `keep_file(folder, name)` accepts,
    and descends only into subfolders that `keep_folder(folder, name)` accepts.
    """
    for folder, subfolders, filenames in os.walk(path, followlinks=symlinks):
        for fname in filenames:
            if keep_file(folder, fname):
                yield func(folder, fname)
        # os.walk descends into whatever remains in this list, so prune it in place
        # (slice assignment keeps the list object os.walk is holding).
        subfolders[:] = [sub for sub in subfolders if keep_folder(folder, sub)]

# Cell
# DICOM attribute keywords to extract from each file; intended to be passed as the
# `meta_cols` argument of `get_dicom_metadata` / `get_all_dicom_metadata`.
kaggle_braintumor_meta_cols = [
    'SpecificCharacterSet', 'ImageType', 'SOPClassUID', 'SOPInstanceUID',
    'AccessionNumber', 'Modality', 'SeriesDescription', 'PatientID',
    'MRAcquisitionType', 'SliceThickness', 'EchoTime', 'NumberOfAverages',
    'ImagingFrequency', 'ImagedNucleus', 'MagneticFieldStrength', 'SpacingBetweenSlices',
    'EchoTrainLength', 'PercentSampling', 'PercentPhaseFieldOfView', 'PixelBandwidth',
    'TriggerWindow', 'ReconstructionDiameter', 'AcquisitionMatrix', 'FlipAngle',
    'SAR', 'PatientPosition', 'StudyInstanceUID', 'SeriesInstanceUID',
    'SeriesNumber', 'InstanceNumber', 'ImagePositionPatient', 'ImageOrientationPatient',
    'Laterality', 'PositionReferenceIndicator', 'SliceLocation', 'InStackPositionNumber',
    'SamplesPerPixel', 'PhotometricInterpretation', 'Rows', 'Columns',
    'PixelSpacing', 'BitsAllocated', 'BitsStored', 'HighBit',
    'PixelRepresentation', 'WindowCenter', 'WindowWidth', 'RescaleIntercept',
    'RescaleSlope', 'RescaleType',
]

# Cell
def get_dicom_metadata(path_to_dicom_file, meta_cols):
    """
    Read one dicom file and return the requested attributes as a dict of strings.

    Params:
        path_to_dicom_file: path to the dicom file, read with `pydicom.dcmread`
        meta_cols: iterable of attribute names to extract

    Each value is converted with `str`. Any attribute whose lookup raises
    AttributeError (e.g. not present in the file) is recorded as the string "NaN".
    """
    dataset = pydicom.dcmread(path_to_dicom_file)

    def _as_text(attr_name):
        try:
            return str(getattr(dataset, attr_name))
        except AttributeError:
            return "NaN"

    return {attr_name: _as_text(attr_name) for attr_name in meta_cols}

# Cell
def get_patient_id(patient_id, width:int=5):
    """
    Returns the zero-padded patient id used as a patient folder name.

    Params:
        patient_id: patient id of the dicom file (e.g. the int 42)
        width: minimum length of the returned string, default: 5

    Returns:
        `str(patient_id)` left-padded with zeros to `width` characters, e.g. 42 -> '00042'.
        Ids that already have `width` or more characters are returned unchanged.
    """
    # zfill pads to an exact width. The previous if/elif ladder always prepended at
    # least one '0', so ids >= 10000 came out 6 characters long.
    return str(patient_id).zfill(width)

# Cell
def get_patient_BraTS21ID_path(row, path_type):
    """
    Returns the folder path of one patient as `'<path_type>/<patient id>/'`.

    Params:
        row: record exposing a `BraTS21ID` attribute (e.g. a dataframe row); the value is converted with `int`
        path_type: leading path component placed before the patient id
    """
    folder_name = get_patient_id(int(row.BraTS21ID))
    return '{}/{}/'.format(path_type, folder_name)

# Cell
def get_all_dicom_metadata(df, meta_cols:list, scan_types:Sequence[str]=('FLAIR', 'T1w', 'T1wCE', 'T2w')):
    """
    Retrieve metadata for every dicom file of every BraTS21ID and return it as a dataframe.

    Params:
        df: dataframe with a `path` column (patient folder) and a `BraTS21ID` column
        meta_cols: list of metadata columns to extract
        scan_types: sequence of medical scan types, searched as subfolders of each patient folder,
            default: ('FLAIR', 'T1w', 'T1wCE', 'T2w')

    Returns:
        A dataframe with one row per matched dicom file, holding the `meta_cols` values
        plus `scan_type` and `id` (the row's BraTS21ID).
    """
    # Default is a tuple (not a list) so no mutable default is shared across calls.
    records = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        patient_dir = Path(row.path)
        for scan_type in scan_types:
            for pth in globtastic(patient_dir / scan_type, file_glob='*.*dcm*'):
                dicom_metadata = get_dicom_metadata(pth, meta_cols)
                dicom_metadata['scan_type'] = scan_type
                dicom_metadata['id'] = row.BraTS21ID
                records.append(dicom_metadata)

    return pd.DataFrame(records)

# Cell
def get_image_plane(data):
    '''
    Returns the MRI's plane from the dicom data.

    Params:
        data: object exposing an `ImageOrientationPatient` attribute (e.g. a dataframe row),
            given either as the string form of a 6-element list such as '[1, 0, 0, 0, 1, 0]'
            or as the 6-element sequence itself.

    Returns:
        'axial', 'coronal' or 'sagittal', or None when the rounded orientation matches none of them.
    '''
    orientation = data.ImageOrientationPatient
    # Values produced by `get_dicom_metadata` are strings, so parse them back into a list;
    # a real sequence is used as-is (literal_eval would reject it).
    if isinstance(orientation, str):
        orientation = ast.literal_eval(orientation)
    # Use the first two components of each 3-value direction triplet (elements 0,1 and 3,4).
    # abs() because a direction with a flipped sign still lies in the same plane.
    x1,y1,_,x2,y2,_ = [abs(round(j)) for j in orientation]
    planes = {
        (1,0,0,0): 'coronal',
        (1,0,0,1): 'axial',
        (0,1,0,0): 'sagittal',
    }
    return planes.get((x1,y1,x2,y2))