# AUTOGENERATED! DO NOT EDIT! File to edit: notebooks/00_core.ipynb (unless otherwise specified).

__all__ = ['download_dataset', 'display_large', 'rf_feature_importance', 'plot_feature_importance']

# Cell
import pandas as pd
from nbdev.showdoc import *
import os
import gdown
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Cell
def download_dataset(dataset_name: str) -> None:
    """Download one of the known datasets from Google Drive into ``../data/``.

    The target directory is created if needed. If the file already exists
    locally, a message is printed and nothing is downloaded.

    Args:
        dataset_name: File name of the dataset. Must be one of the keys of
            the ``name_to_id`` table below.

    Raises:
        KeyError: If ``dataset_name`` is not a known dataset.
        RuntimeError: If ``gdown.download`` returns ``None`` instead of the
            output path (how some gdown releases report a failed download).
        Exception: Any exception raised by ``gdown.download`` is re-raised
            after a short hint has been printed.
    """
    name_to_id = {
        "susy.csv.gz": "1rnR1v-BkMOtzV80R7jFyU1cwO3fGYrQs",
        "susy.feather": "1PxCruwO42GV7FKtwZDXah7iGjDib7YPM",
        "susy_train.feather": "1ezeCZycZ3BrEh-qOLiSJF40YowYEbbTH",
        "susy_test.feather": "1UM8sheb4jzQa16haG6HnVbpJCxZwN2yE",
        "susy_sample.feather": "1l4x_uBeup4eciLDK4YjnfY_G8yTpXLkP",
    }

    # Validate before touching the filesystem so that a misspelled name does
    # not leave an empty data directory behind.
    if dataset_name not in name_to_id:
        raise KeyError("File not on Google Drive.")

    path = "../data/"
    os.makedirs(path, exist_ok=True)
    target = path + dataset_name

    if os.path.exists(target):
        print(f"Dataset already exists at '{target}' and is not downloaded again.")
        return

    file_url = "https://drive.google.com/uc?id=" + name_to_id[dataset_name]
    try:
        downloaded = gdown.download(file_url, target, quiet=True)
    except Exception:
        print("Something went wrong during the download! Try again.")
        # Bare raise keeps the original traceback intact.
        raise

    # A None result means gdown gave up without raising; do not report success.
    if downloaded is None:
        print("Something went wrong during the download! Try again.")
        raise RuntimeError(f"Download of {dataset_name} dataset failed.")

    print(f"Download of {dataset_name} dataset complete.")

# Cell
def display_large(df):
    """Show `df` with pandas' row and column display limits raised to 1000.

    The raised limits apply only for the duration of the call; they are set
    through a `pd.option_context` block.
    """
    limit = 1000
    wide_view = pd.option_context(
        "display.max_rows", limit, "display.max_columns", limit
    )
    with wide_view:
        # NOTE(review): `display` is not defined in this module; presumably it
        # is the IPython/notebook display function — confirm it is in scope.
        display(df)

# Cell
def rf_feature_importance(fitted_model, df):
    """Build a pandas.DataFrame of feature importances, most important first.

    Pairs `df.columns` with `fitted_model.feature_importances_` in two
    columns, "Column" and "Importance", and sorts the rows by "Importance"
    in descending order.
    """
    importance_table = pd.DataFrame(
        {
            "Column": df.columns,
            "Importance": fitted_model.feature_importances_,
        }
    )
    return importance_table.sort_values(by="Importance", ascending=False)

# Cell
def plot_feature_importance(feature_importance):
    """Draw a horizontal bar plot of feature importances on a new 12x8 figure.

    `feature_importance` is passed to seaborn as `data` and must provide the
    fields "Column" (y axis) and "Importance" (x axis). Returns whatever
    `sns.barplot` returns.
    """
    _, axes = plt.subplots(figsize=(12, 8))
    # The freshly created axes are also pyplot's current axes, so naming them
    # explicitly draws on the same target as the implicit default.
    return sns.barplot(
        data=feature_importance,
        x="Importance",
        y="Column",
        color="b",
        ax=axes,
    )