# AUTOGENERATED! DO NOT EDIT! File to edit: ../../notebooks/04_visualization.ipynb.

# %% auto 0
__all__ = ['plotly_confusion_matrix', 'get_classification_report', 'get_confusion_matrix']

# %% ../../notebooks/04_visualization.ipynb 1
import plotly.figure_factory as ff
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

# %% ../../notebooks/04_visualization.ipynb 2
def plotly_confusion_matrix(labels, y, _y):
    """Build, show, and return an annotated Plotly confusion-matrix heatmap.

    Parameters
    ----------
    labels : sequence
        Class names; ``labels[k]`` names class index ``k``.
    y : iterable of int
        True class indices (each expected in ``range(len(labels))``).
    _y : iterable of int
        Predicted class indices, aligned element-wise with ``y``.

    Returns
    -------
    plotly.graph_objects.Figure
        Heatmap whose columns (x axis, labelled "True Class") are true classes
        and whose rows (y axis, labelled "Predicted Class") are predicted
        classes. The figure is also displayed via ``fig.show()`` before it is
        returned.
    """
    n = len(labels)
    names = list(labels)

    # counts[p][t]: number of samples whose true class is t and predicted class is p.
    counts = [[0] * n for _ in range(n)]
    # support[t]: total number of samples whose true class is t.
    support = [0] * n
    for true_idx, pred_idx in zip(y, _y):
        counts[pred_idx][true_idx] += 1
        support[true_idx] += 1

    # Cell annotations: zero counts are blanked to reduce visual clutter.
    z_labels = [[str(c) if c != 0 else "" for c in row] for row in counts]

    # Hover text per cell, indexed the same way as counts ([pred][true]).
    hover = [[""] * n for _ in range(n)]
    for t in range(n):
        for p in range(n):
            count = counts[p][t]
            if p == t:
                # A class with no true samples has no defined accuracy; report
                # "n/a" instead of raising ZeroDivisionError.
                acc = count / support[t] if support[t] else "n/a"
                hover[p][t] = (
                    f"Correctly predicted {count} out of {support[t]} "
                    f"{names[t]} with accuracy {acc}"
                )
            elif count:
                hover[p][t] = (
                    f"Incorrectly predicted {count} out of {support[t]} "
                    f"{names[t]} as {names[p]}"
                )

    fig = ff.create_annotated_heatmap(
        counts,
        x=list(names),
        y=list(names),
        text=hover,
        annotation_text=z_labels,
        hoverinfo="text",
        colorscale="Blues",
    )

    fig.update_layout(width=850, height=550, margin=dict(t=100, l=200))

    # Axis titles are added as paper-referenced annotations placed outside
    # the plotting area: x title below it, y title to its left (rotated).
    fig.add_annotation(
        dict(
            font=dict(color="#094973", size=16),
            x=0.5,
            y=-0.10,
            showarrow=False,
            text="True Class",
            xref="paper",
            yref="paper",
        )
    )
    fig.add_annotation(
        dict(
            font=dict(color="#094973", size=16),
            x=-0.17,
            y=0.5,
            showarrow=False,
            text="Predicted Class",
            textangle=-90,
            xref="paper",
            yref="paper",
        )
    )

    fig.show()
    return fig


# %% ../../notebooks/04_visualization.ipynb 3
def get_classification_report(true_categories, predicted_categories, labels, num_classes=None):
    """Compute, print, and return a text classification report.

    Parameters
    ----------
    true_categories : array-like of int
        Ground-truth class indices.
    predicted_categories : array-like of int
        Predicted class indices, aligned with ``true_categories``.
    labels : sequence of str
        Display names for the classes; ``labels[k]`` names class index ``k``.
    num_classes : int, optional
        Number of classes to report on (indices ``0 .. num_classes - 1``).
        Defaults to ``len(labels)`` so that the report rows always line up
        with ``labels``.

    Returns
    -------
    str
        The report produced by ``sklearn.metrics.classification_report``.
    """
    if num_classes is None:
        # Derive the class count from the names instead of a hard-coded 7, so
        # the class indices and target_names always have the same length.
        num_classes = len(labels)

    cl_report = classification_report(
        true_categories,
        predicted_categories,
        labels=list(range(num_classes)),
        target_names=labels,
        output_dict=False,
    )

    print(f"\nClassification Report\n{cl_report}")
    return cl_report

# %% ../../notebooks/04_visualization.ipynb 4
def get_confusion_matrix(
    true_categories,
    predicted_categories,
    labels,
    save_path="confusion_matrix.png",
    normalize="pred",
):
    """Compute a normalized confusion matrix, save its plot, and return it.

    Parameters
    ----------
    true_categories : array-like
        Ground-truth classes.
    predicted_categories : array-like
        Predicted classes, aligned with ``true_categories``.
    labels : sequence
        Display labels for the matrix axes.
    save_path : str or path-like, optional
        Where the rendered plot is written. Defaults to
        ``"confusion_matrix.png"``, which is the original hard-coded path.
    normalize : {"true", "pred", "all"} or None, optional
        Normalization mode forwarded to ``sklearn.metrics.confusion_matrix``.
        Defaults to ``"pred"``, i.e. normalized over predicted classes
        (columns).

    Returns
    -------
    numpy.ndarray
        The confusion matrix returned by ``sklearn.metrics.confusion_matrix``.
    """
    result = confusion_matrix(true_categories, predicted_categories, normalize=normalize)
    disp = ConfusionMatrixDisplay(confusion_matrix=result, display_labels=labels)
    disp.plot()

    # Act on the display's own figure/axes rather than pyplot's implicit
    # "current" figure, so the correct figure is rotated, saved, and closed.
    plt.setp(disp.ax_.get_xticklabels(), rotation=35)
    disp.figure_.savefig(save_path)
    plt.close(disp.figure_)

    return result
