# AUTOGENERATED! DO NOT EDIT! File to edit: 00_core.ipynb (unless otherwise specified).

# Names exported by ``from <module> import *``.
__all__ = [
    'is_default_index',
    'is_multiindex_row_df',
    'is_multiindex_col_df',
    'FastDataDataframeUtilities',
    'split_list_to_columns',
    'pivot_table',
    'FastDataSeriesUtilities',
    'extract_number_from_string',
    'bin_column',
    'fill_empty',
    'replace_based_on_condition',
]

# Cell
import pandas as pd
import numpy as np
import re
from fastcore.all import *

# Cell
def is_default_index(df):
    """Tell whether the row labels of ``df`` are exactly 0, 1, ..., len(df) - 1.

    The index *values* are compared element-wise against a ``RangeIndex`` of
    the same length; the index name is not inspected.

    Returns the result of ``.all()`` on that comparison, which is truthy only
    when every label equals its position.
    """
    positional = pd.RangeIndex(start=0, stop=df.shape[0], step=1)
    labels_match = df.index == positional
    return labels_match.all()

# Cell
def is_multiindex_row_df(df):
    """Return True only if ``df`` is a DataFrame whose row index is a MultiIndex.

    Anything that is not a DataFrame yields False.
    """
    return isinstance(df, pd.DataFrame) and isinstance(df.index, pd.MultiIndex)

# Cell
def is_multiindex_col_df(df):
    """Return True only if ``df`` is a DataFrame whose columns are a MultiIndex.

    Anything that is not a DataFrame yields False.
    """
    return isinstance(df, pd.DataFrame) and isinstance(df.columns, pd.MultiIndex)

# Cell
@pd.api.extensions.register_dataframe_accessor('fdt')
class FastDataDataframeUtilities:
    """DataFrame helpers reachable as ``df.fdt`` on any pandas DataFrame."""

    def __init__(self, pandas_obj):
        # The DataFrame this accessor instance was created for.
        self._obj = pandas_obj

    def remove_indexes(self, axis='all'):
        """Return a copy of the frame with hierarchical/custom indexes flattened.

        Parameters
        ----------
        axis : {'all', 'columns', 'index'}
            Which side to flatten. Any other value returns an untouched copy.

        Returns
        -------
        DataFrame
            * columns: MultiIndex column labels are joined level by level
              with ``'_'`` (each level converted with ``str``);
            * index: a MultiIndex, or any index that is not the default
              0..n-1 range, is moved into columns via ``reset_index``.
        """
        flat = self._obj.copy()
        touch_columns = axis in ('columns', 'all')
        touch_index = axis in ('index', 'all')

        if touch_columns and is_multiindex_col_df(flat):
            flat.columns = flat.columns.map(
                lambda levels: '_'.join(str(level) for level in levels)
            )

        if touch_index and (is_multiindex_row_df(flat) or not is_default_index(flat)):
            flat = flat.reset_index()

        return flat

# Cell
@patch_to(FastDataDataframeUtilities)
def split_list_to_columns(self, column, separator=',', list_marker='na', split_type='unique'):
    """Expand a list-like column into several columns.

    Parameters
    ----------
    column : label
        Column holding either Python lists or delimited strings.
    separator : str
        Delimiter used to split string values.
    list_marker : {'na', 'square_brackets', 'parentheses'}
        Wrapping characters to strip from string values before splitting:
        ``'square_brackets'`` removes ``[``, ``]`` and single quotes,
        ``'parentheses'`` removes ``(`` and ``)``, ``'na'`` strips nothing.
        Ignored when the column already holds lists.
    split_type : {'unique', 'order'}
        ``'unique'`` creates one count column per distinct item (an empty
        string item is reported as ``'blank'``); ``'order'`` creates one
        column per item position.

    Returns
    -------
    DataFrame
        A copy of the frame without ``column`` and with the new columns
        appended. All column labels are converted to ``str``. In ``'unique'``
        mode every missing value left in the result is replaced by 0.

    Raises
    ------
    ValueError
        If ``split_type`` is not one of the supported values.
    """
    if split_type not in ('unique', 'order'):
        raise ValueError(
            "split_type must be 'unique' or 'order', got {!r}".format(split_type)
        )

    df = self._obj.copy()

    # Inspect the first non-null value by position: a label lookup (``[0]``)
    # breaks as soon as label 0 is absent or holds a null.
    non_null = df[column].dropna()
    already_lists = (not non_null.empty) and isinstance(non_null.iloc[0], list)

    if not already_lists:
        # ``regex=True`` is explicit because pandas 2.0 switched the default of
        # ``Series.str.replace`` to literal matching, which would leave the
        # markers in place.
        if list_marker == 'square_brackets':
            df[column] = df[column].str.replace(r"[\[\]']", "", regex=True)
        elif list_marker == 'parentheses':
            df[column] = df[column].str.replace(r'([()])', '', regex=True)
        # In 'order' mode the split happens below with ``expand=True``.
        if split_type == 'unique':
            df[column] = df[column].str.split(separator)

    minus_pivoted = df.drop(column, axis=1)

    if split_type == 'unique':
        # One row per list item (same index label as its source row), then
        # indicator columns summed back per source row. ``groupby(level=0)``
        # replaces ``sum(level=0)``, which pandas 2.0 removed; ``sort=False``
        # keeps the original row order; ``dtype=int`` keeps the counts numeric
        # regardless of the pandas default dummy dtype.
        items = df[column].explode()
        exploded = pd.get_dummies(items, dtype=int).groupby(level=0, sort=False).sum()
        if '' in exploded.columns:
            exploded = exploded.rename(columns={'': 'blank'})
        result = pd.concat([minus_pivoted, exploded], axis=1).fillna(0)
    else:
        exploded = df[column].str.split(separator, expand=True)
        result = pd.concat([minus_pivoted, exploded], axis=1)

    result.columns = result.columns.map(str)
    return result

# Cell
@patch_to(FastDataDataframeUtilities)
def pivot_table(self, index_type, **kwargs):
    """Pivot the frame and optionally flatten the resulting indexes.

    Parameters
    ----------
    index_type : str
        ``'flat'`` flattens the pivoted result with the accessor's
        ``remove_indexes(axis='all')``; any other value returns the pivot
        table as pandas produced it.
    **kwargs
        Forwarded unchanged to ``DataFrame.pivot_table``.

    Returns
    -------
    DataFrame
        The pivot table, flattened when ``index_type == 'flat'``.
    """
    # ``DataFrame.pivot_table`` builds a new frame, so no defensive copy of
    # the source is needed.
    pivoted = self._obj.pivot_table(**kwargs)

    if index_type == 'flat':
        # The accessor's flattening method is ``remove_indexes``; the
        # previously referenced ``flatten_multiindex`` does not exist on it.
        pivoted = pivoted.fdt.remove_indexes(axis='all')

    return pivoted

# Cell
@pd.api.extensions.register_series_accessor('fdt')
class FastDataSeriesUtilities:
    """Series helpers reachable as ``series.fdt`` on any pandas Series."""

    def __init__(self, pandas_obj):
        # The Series this accessor instance was created for.
        self._obj = pandas_obj

    def find_between_text(self, start_string, end_string):
        """Extract the text found between two delimiters in each string.

        The delimiters are inserted into the pattern as-is, so they are
        interpreted as regular-expression fragments, and the capture between
        them is greedy (``(.*)``).

        Returns whatever ``Series.str.extract`` yields for that
        single-group pattern; rows without a match hold a missing value.
        """
        pattern = ''.join((start_string, '(.*)', end_string))
        return self._obj.str.extract(pattern)

# Cell
@patch_to(FastDataSeriesUtilities)
def extract_number_from_string(self, dtype):
    """Extract the first run of digits from each string and cast it.

    Parameters
    ----------
    dtype : dtype-like
        Target type handed to ``astype``.

    Returns
    -------
    The result of ``Series.str.extract`` for the single group ``(\\d+)``
    (a one-column frame under pandas' default ``expand=True``), converted
    with ``astype(dtype)``. Rows without digits are missing before the cast,
    so the chosen ``dtype`` must be able to represent missing values.
    """
    # Raw string: ``\d`` inside a plain literal is an invalid escape sequence
    # and triggers a warning on recent Python versions.
    digits = self._obj.str.extract(r'(\d+)')
    return digits.astype(dtype)

# Cell
@patch_to(FastDataSeriesUtilities)
def bin_column(self, **kwargs):
    """Bin the series into intervals according to ``mode``.

    Keyword arguments
    -----------------
    mode : {'size', 'number', 'quantiles', 'custom'}
        Required. Selects the binning strategy and which other keys are read:

        * ``'size'``: ``start``, ``end``, ``size`` - fixed-width intervals
          built with ``pd.interval_range(start=..., freq=size, end=...)``.
        * ``'number'``: ``bin_number`` - passed as ``bins`` to ``pd.cut``.
        * ``'quantiles'``: ``quantiles`` - passed as ``q`` to ``pd.qcut``.
        * ``'custom'``: ``breaks``, ``closed`` - intervals built with
          ``pd.IntervalIndex.from_breaks``.

    Returns
    -------
    The binned series produced by ``pd.cut`` / ``pd.qcut``.

    Raises
    ------
    KeyError
        If ``mode`` or a key required by the chosen mode is missing.
    ValueError
        If ``mode`` is not one of the supported values (previously this
        silently returned ``None``).
    """
    series = self._obj
    mode = kwargs['mode']

    if mode == 'size':
        bins = pd.interval_range(
            start=kwargs['start'], freq=kwargs['size'], end=kwargs['end']
        )
        return pd.cut(series, bins=bins)

    if mode == 'number':
        return pd.cut(series, bins=kwargs['bin_number'])

    if mode == 'quantiles':
        return pd.qcut(series, q=kwargs['quantiles'])

    if mode == 'custom':
        bins = pd.IntervalIndex.from_breaks(kwargs['breaks'], closed=kwargs['closed'])
        return pd.cut(series, bins=bins)

    raise ValueError(
        "mode must be one of 'size', 'number', 'quantiles' or 'custom', "
        "got {!r}".format(mode)
    )

# Cell
@patch_to(FastDataSeriesUtilities)
def fill_empty(self, **kwargs):
    """Fill missing values in the series.

    Keyword arguments
    -----------------
    mode : {'function', 'value'}
        Required. ``'value'`` fills with the constant given in ``value``;
        ``'function'`` fills according to ``function``.
    function : {'ffill', 'bfill', 'mean', 'most_frequent'}
        Read only when ``mode == 'function'``: forward fill, backward fill,
        the series mean, or the first value returned by ``Series.mode()``.
    value : scalar
        Read only when ``mode == 'value'``.

    Returns
    -------
    Series
        The filled series. An unrecognised ``mode`` or ``function`` leaves
        the series unchanged, as does ``'most_frequent'`` on a series with
        no non-null value to take the mode from.

    Raises
    ------
    KeyError
        If ``mode`` or the key required by the chosen mode is missing.
    """
    series = self._obj
    mode = kwargs['mode']

    if mode == 'value':
        return series.fillna(kwargs['value'])

    if mode == 'function':
        function = kwargs['function']
        # ``ffill()`` / ``bfill()`` replace ``fillna(method=...)``, which
        # recent pandas versions deprecate and then drop.
        if function == 'ffill':
            return series.ffill()
        if function == 'bfill':
            return series.bfill()
        if function == 'mean':
            return series.fillna(series.mean())
        if function == 'most_frequent':
            most_frequent = series.mode()
            # ``mode()`` is empty when every value is null; there is then
            # nothing to fill with, so avoid indexing into it.
            if most_frequent.empty:
                return series
            return series.fillna(most_frequent.iloc[0])

    return series

# Cell
@patch_to(FastDataSeriesUtilities)
def replace_based_on_condition(self, cond, when, replace_with=np.nan):
    """Replace the entries selected by a boolean condition.

    Parameters
    ----------
    cond
        Condition forwarded to ``Series.mask`` / ``Series.where``.
    when : bool
        ``True`` replaces entries where ``cond`` holds (``Series.mask``);
        ``False`` replaces entries where it does not (``Series.where``).
        Any value equal to neither leaves the series unchanged.
    replace_with : scalar, default ``np.nan``
        Replacement value. ``np.nan`` is used instead of the ``np.NaN``
        alias, which NumPy 2.0 removed.

    Returns
    -------
    Series
        The series with the selected entries replaced.
    """
    series = self._obj

    # Equality (not truthiness) is kept on purpose so that only values equal
    # to True / False pick a branch, exactly as before.
    if when == True:  # noqa: E712
        return series.mask(cond=cond, other=replace_with)
    if when == False:  # noqa: E712
        return series.where(cond=cond, other=replace_with)

    return series