# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/07_dataset_info.ipynb.

# %% auto 0
__all__ = ['MedDataset']

# %% ../nbs/07_dataset_info.ipynb 2
from .vision_core import *

import multiprocessing as mp
from functools import partial
import pandas as pd
import numpy as np
import glob

# %% ../nbs/07_dataset_info.ipynb 4
class MedDataset():
    '''A class to extract and present information about the dataset.'''

    def __init__(self, path=None, # Path to the image folder
                 postfix:str='', # Specify the file type if there are different files in the folder
                 img_list:list=None, # A list with image path
                 reorder:bool=False, # Whether to reorder the data to be closest to canonical (RAS+) orientation
                 dtype:(MedImage, MedMask)=MedImage, # Load data as datatype
                 max_workers:int=1 # The number of worker processes used to read the images
                ):
        '''Constructs all the necessary attributes for the MedDataset object.

        Reads every image (from `path`, or from `img_list` when no `path` is given)
        and stores one row of metadata per file in `self.df`.
        '''

        self.path = path
        self.postfix = postfix
        self.img_list = img_list
        self.reorder = reorder
        self.dtype = dtype
        self.max_workers = max_workers
        self.df = self._create_data_frame()

    def _create_data_frame(self):
        '''Private method that returns a dataframe with information about the dataset.

        When `self.path` is set, `self.img_list` is rebuilt from the files in that
        folder whose names match `*{postfix}*`.

        Returns:
            DataFrame: One row per image file (columns as produced by `_get_data_info`).
                When there are no images, an empty DataFrame with the core columns.
        '''

        if self.path:
            # glob order is filesystem-dependent; sort so the row order is reproducible.
            self.img_list = sorted(glob.glob(f'{self.path}/*{self.postfix}*'))
            if not self.img_list: print('Could not find images. Check the image path')

        if not self.img_list:
            # Nothing to read (no matches, or neither `path` nor `img_list` given).
            # Return a well-formed empty frame instead of failing on the missing
            # `orientation` column below.
            return pd.DataFrame(columns=['path', 'dim_0', 'dim_1', 'dim_2',
                                         'voxel_0', 'voxel_1', 'voxel_2', 'orientation'])

        if self.max_workers == 1:
            # A single worker gains nothing from a process pool; read in-process.
            data_info_dict = [self._get_data_info(fn) for fn in self.img_list]
        else:
            # The context manager shuts the worker processes down once map() has returned.
            with mp.Pool(self.max_workers) as pool:
                data_info_dict = pool.map(self._get_data_info, self.img_list)

        df = pd.DataFrame(data_info_dict)
        if df.orientation.nunique() > 1: print('The volumes in this dataset have different orientations. Recommended to pass in the argument reorder=True when creating a MedDataset object for this dataset')
        return df

    def summary(self):
        '''Summary DataFrame of the dataset with example path for similar data.

        Rows are grouped by identical shape, voxel spacing and orientation. Each group
        reports its lexicographically smallest path and its file count, sorted by count
        in descending order.
        '''

        columns = ['dim_0', 'dim_1', 'dim_2', 'voxel_0', 'voxel_1', 'voxel_2', 'orientation']
        return self.df.groupby(columns,as_index=False).agg(example_path=('path', 'min'), total=('path', 'size')).sort_values('total', ascending=False)

    def suggestion(self):
        '''Voxel value that appears most often in dim_0, dim_1 and dim_2, and whether the data should be reoriented.

        Returns:
            tuple: ([mode of voxel_0, mode of voxel_1, mode of voxel_2], self.reorder)
        '''
        resample = [self.df.voxel_0.mode()[0], self.df.voxel_1.mode()[0], self.df.voxel_2.mode()[0]]

        return resample, self.reorder

    def _get_data_info(self, fn:str):
        '''Private method to collect information about an image file.

        Args:
            fn: Image file path.

        Returns:
            dict: path, dim_0..dim_2, voxel_0..voxel_2 (rounded to 4 decimals) and
                orientation. When `self.dtype is MedMask`, it also has one
                `voxel_count_<label>` entry per label returned by `count_labels()`.
        '''

        o,_ = med_img_reader(fn, dtype=self.dtype, reorder=self.reorder, only_tensor=False)

        # Spatial sizes are read from shape[1:4]; assumes axis 0 is the channel axis — TODO confirm.
        info_dict = {'path': fn,  'dim_0': o.shape[1],  'dim_1': o.shape[2],  'dim_2' :o.shape[3],
                     'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4),
                     'orientation': f'{"".join(o.orientation)}+'}

        if self.dtype is MedMask:
            mask_labels_dict = o.count_labels()
            mask_labels_dict = {f'voxel_count_{int(key)}': val for key, val in mask_labels_dict.items()}
            info_dict.update(mask_labels_dict)

        return info_dict

    def get_largest_img_size(self,
                             resample:list=None # A list with voxel spacing [dim_0, dim_1, dim_2]
                            ) -> list:
        '''Get the largest image size in the dataset.

        When `resample` is given, each image's size is first rescaled to that voxel
        spacing (size * original_spacing / resample). The per-axis maxima are then
        rounded.
        '''
        if resample is not None:

            org_voxels = self.df[["voxel_0", "voxel_1", 'voxel_2']].values
            org_dims = self.df[["dim_0", "dim_1", 'dim_2']].values

            # Broadcasts the 3-element spacing across every row (file).
            ratio = org_voxels/resample
            new_dims = (org_dims * ratio).T
            return [new_dims[0].max().round(), new_dims[1].max().round(), new_dims[2].max().round()]

        # Fixed: previously referenced an undefined module-level `df` (NameError).
        else: return [self.df.dim_0.max(), self.df.dim_1.max(), self.df.dim_2.max()]
