"""Testing code generated by nbdev in unpackai/tabular/data.py"""
# Generated automatically from notebook nbs/35_tabular_data.ipynb

from unpackai.tabular.data import *

# Test Cell
import pandas as pd
from pathlib import Path
import pytest
import sys

# Load the housing fixture shared by the tests below. The CSV is looked up
# relative to the current directory first and, if that file is missing,
# relative to the parent directory.
_fixture_relpath = Path("test") / "tabular" / "california_housing.csv"
try:
    root_dir = Path(".")
    df = pd.read_csv(root_dir / _fixture_relpath)
except FileNotFoundError:
    root_dir = Path("../")
    df = pd.read_csv(root_dir / _fixture_relpath)

# Test Cell
def test_no_missing_value_basic():
    """Check which columns `no_missing_values` keeps on a frame with gaps.

    Every column of a 100-row sample is overwritten so that the column at
    position ``i`` holds ``10 * (i + 1)`` integers followed by ``None``
    padding. The test expects the result to contain every column except
    the first three, in their original order.
    """
    from copy import deepcopy

    sample = deepcopy(df).head(100)

    # create missing value manually: the later the column, the fewer gaps
    for position, column in enumerate(sample):
        filled = list(range(10 * (position + 1)))
        padding = [None] * (9 - position) * 10
        sample[column] = filled + padding

    filtered = no_missing_values(sample)
    kept = list(filtered.columns)
    expected = list(sample.columns)[3:]
    assert (
        kept == expected
    ), f"""
        Incorrect column filtering
        based on missing value: {kept}"""


# Test Cell
@pytest.mark.skipif(sys.platform != "linux", reason="plot test only for linux")
def test_plot_hist_run_through():
    """
    Smoke test: plot_hist runs on the first 100 rows without raising.
    """
    sample = df.head(100)
    plot_hist(sample)


# Test Cell
# hide
def repeat_cols(df, n=2):
    """Return a copy of `df` with each original column duplicated.

    For every column ``feat`` present in the copy at the start, the columns
    ``feat_1`` ... ``feat_{n-1}`` are added with the same values. The input
    frame is left untouched; with ``n <= 1`` the copy is returned unchanged.
    """
    widened = df.copy()
    # Snapshot the labels first so only the original columns are duplicated.
    original_labels = list(widened.columns)
    suffixes = range(1, n)
    for label in original_labels:
        for suffix in suffixes:
            clone_label = f"{label}_{suffix}"
            widened[clone_label] = widened[label]
    return widened


def add_cat_feat(df, fname):
    """Add a letter-valued column named `fname` to `df` in place.

    Row ``i`` receives the letter at position ``i % 20`` of the lowercase
    ASCII alphabet, so the values cycle through 'a'..'t'. Nothing is
    returned; the caller's frame is modified directly.
    """
    import string

    # Only the first 20 lowercase letters take part in the cycle.
    pool = string.ascii_lowercase[:20]
    row_count = len(df)
    df[fname] = [pool[row % len(pool)] for row in range(row_count)]


# Test Cell
# hide
@pytest.mark.skipif(sys.platform != "linux", reason="plot test only for linux")
def test_return_df_true():
    """plot_feat_correlations(..., return_df=True) must yield a DataFrame."""
    sample = df.head(100)
    result = plot_feat_correlations(sample, return_df=True)
    got_frame = isinstance(result, pd.DataFrame)
    assert got_frame, f"should return a dataframe and got {type(result)}"


@pytest.mark.skipif(sys.platform != "linux", reason="plot test only for linux")
def test_return_df_false():
    """plot_feat_correlations(..., return_df=False) must return None."""
    sample = df.head(100)
    result = plot_feat_correlations(sample, return_df=False)
    assert result is None, f"should return None and got {result}"


# Test Cell
# hide
@pytest.mark.skipif(sys.platform != "linux", reason="plot test only for linux")
def test_non_numericals_dropped():
    """Letter-valued columns must not show up in the correlation frame.

    Five columns named ``categorical_1`` .. ``categorical_5`` are added to a
    100-row copy of the fixture. The returned frame is expected to have
    exactly five columns fewer than the input and no column whose name
    contains "categorical".
    """
    # Setup: work on a copy so the shared fixture is not modified
    sample = df.head(100).copy()
    extra_names = [f"categorical_{idx}" for idx in range(1, 6)]
    for extra in extra_names:
        add_cat_feat(sample, extra)

    # Run
    result = plot_feat_correlations(sample, return_df=True)

    # Verify
    expected_width = sample.shape[1] - len(extra_names)
    assert result.shape[1] == expected_width
    joined_names = " ".join(result.columns)
    assert "categorical" not in joined_names
