# (c) Kevin Dunn, 2010-2021. MIT License. Based on own private work over the years.

# Built-in libraries
import json
from typing import Dict

import plotly.graph_objects as go
from pydantic import BaseModel, validator


def plot_pre_checks(model, pc_horiz, pc_vert, pc_depth) -> bool:
    """Validate the component numbers requested for a score or loadings plot.

    Parameters
    ----------
    model : MVmodel object (PCA, or PLS)
        Only ``model.A``, the number of fitted components, is read.
    pc_horiz, pc_vert : int
        Components for the horizontal and vertical axes; each must satisfy 1 <= pc <= model.A.
    pc_depth : int
        Component for the third (depth) axis, 1 <= pc_depth <= model.A; use -1 or 0 when no
        third axis is wanted (a 2D plot).

    Returns
    -------
    bool
        Always True; a failed check raises instead.

    Raises
    ------
    AssertionError
        If a component number is out of range, or the axes do not use distinct components.
        Raised explicitly (not via ``assert``) so the checks still run under ``python -O``,
        while keeping the exception type that callers may already catch.
    """
    n_components = model.A
    if not 0 < pc_horiz <= n_components:
        raise AssertionError(
            f"The model has {n_components} components. "
            f"Ensure that 1 <= pc_horiz<={n_components}."
        )
    if not 0 < pc_vert <= n_components:
        raise AssertionError(
            f"The model has {n_components} components. "
            f"Ensure that 1 <= pc_vert<={n_components}."
        )
    if not -1 <= pc_depth <= n_components:
        raise AssertionError(
            f"The model has {n_components} components. Ensure that pc_depth is -1 or 0 "
            f"(2D plot), or that 1 <= pc_depth<={n_components}."
        )
    # A pc_depth of -1 or 0 can never equal the (positive) horizontal/vertical components, so
    # this only rejects genuinely repeated axes.
    if len({pc_horiz, pc_vert, pc_depth}) != 3:
        raise AssertionError("Specify distinct components for each axis")

    return True


def score_plot(
    model,
    pc_horiz: int = 1,
    pc_vert: int = 2,
    pc_depth: int = -1,
    items_to_highlight: Dict[str, list] = None,
    settings: Dict = None,
    fig=None,
) -> go.Figure:
    """Generates a score plot for the given latent variable model.

    The plot is 2-dimensional, unless `pc_depth` >= 1, in which case a 3-dimensional score plot
    is generated.

    Parameters
    ----------
    model : MVmodel object (PCA, or PLS)
        A latent variable model generated by this library. Its `x_scores` (rows labelled by
        observation, columns labelled by component number), `A`, and `ellipse_coordinates(...)`
        are used.
    pc_horiz : int, optional
        Which component to plot on the horizontal axis, by default 1 (the first component)
    pc_vert : int, optional
        Which component to plot on the vertical axis, by default 2 (the second component)
    pc_depth : int, optional
        If pc_depth >= 1, then a 3D score plot is generated, with this component on the 3rd axis.
        By default -1, which gives a 2D plot.
    items_to_highlight : dict, optional
        keys:   a string which can be json.loads(...) and turns into a Plotly marker specifier.
        values: a list of identifiers for the items to highlight [index names]
        For example:
            items_to_highlight = {'{"color": "red", "symbol": "cross"}': items_in_red}

            will draw the observations listed in `items_in_red` in that colour and shape.

        Identifiers that are not in the model's index are ignored. An observation listed under
        several keys is drawn with the styling of the first of those keys only.

    settings : dict
        Default settings are = {
            "show_ellipse": True [bool],
                Should the Hotelling's T2 ellipse be added (2D plots only)

            "ellipse_conf_level": 0.95 [float]
                If the ellipse is added, which confidence level is used. 0.0 < number < 1.00.

            "title": f"Score plot of ... "
                Overall plot title

            "show_labels": False,
                Adds a label for each observation. Labels are always available in the hover.

            "show_legend": True,
                Shows a clickable legend (allows to turn the ellipse(s) on/off)

            "html_image_height": 500,
                in pixels

            "html_aspect_ratio_w_over_h": 16/9,
                sets the image width, as a ratio of the height

        }
    fig : plotly.graph_objects.Figure, optional
        An existing figure to add the plot to; a new figure is created if not given.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure containing the score plot.
    """
    plot_pre_checks(model, pc_horiz, pc_vert, pc_depth)
    margin_dict: Dict = dict(l=10, r=10, b=5, t=80)  # Defaults: l=80, r=80, t=100, b=80
    is_3d = pc_depth >= 1

    class Settings(BaseModel):
        show_ellipse: bool = True
        ellipse_conf_level: float = 0.95  # constrained to (0, 1) by the validator below
        title: str = f"Score plot of component {pc_horiz} vs component {pc_vert}" + (
            f" vs component {pc_depth}" if pc_depth > 0 else ""
        )
        show_labels: bool = False
        show_legend: bool = True
        html_image_height: float = 500.0
        html_aspect_ratio_w_over_h: float = 16 / 9.0

        @validator("ellipse_conf_level")
        def check_ellipse_conf_level(cls, v):
            if not 0 < v < 1:
                raise ValueError("0.0 < `ellipse_conf_level` < 1.0")
            return v

    setdict = Settings(**(settings or {})).dict()
    if fig is None:
        fig = go.Figure()

    name = "X-space scores [T]"
    fig.update_layout(
        xaxis_title_text=f"PC {pc_horiz}", yaxis_title_text=f"PC {pc_vert}"
    )

    # Split the observations into the default group and the highlighted groups. An observation
    # claimed by an earlier key is not offered to later keys; unknown identifiers are dropped.
    # Filtering the index (instead of set arithmetic) keeps the model's observation order, so
    # the traces are identical from run to run.
    default_index = list(model.x_scores.index)
    highlights: Dict[str, list] = {}
    for key, items in (items_to_highlight or {}).items():
        wanted = set(items)
        highlights[key] = [idx for idx in default_index if idx in wanted]
        default_index = [idx for idx in default_index if idx not in wanted]

    mode = "markers+text" if setdict["show_labels"] else "markers"

    def scores_trace(index: list, marker: dict):
        """Build one 2D or 3D scatter trace of the scores for the observations in `index`."""
        common = dict(
            x=model.x_scores.loc[index, pc_horiz],
            y=model.x_scores.loc[index, pc_vert],
            name=name,
            mode=mode,
            marker=marker,
            text=list(index),
            textposition="top center",
        )
        if is_3d:
            return go.Scatter3d(z=model.x_scores.loc[index, pc_depth], **common)
        return go.Scatter(**common)

    default_marker = dict(color="darkblue", symbol="circle")
    if not is_3d:
        default_marker["size"] = 7
    fig.add_trace(scores_trace(default_index, default_marker))

    # Items to highlight, if any: each key is a JSON-encoded marker specification.
    for key, index in highlights.items():
        fig.add_trace(scores_trace(index, json.loads(key)))

    if not is_3d and setdict["show_ellipse"]:
        ellipse = model.ellipse_coordinates(
            score_horiz=pc_horiz,
            score_vert=pc_vert,
            T2_limit_conf_level=setdict["ellipse_conf_level"],
        )
        fig.add_hline(y=0, line_color="black")
        fig.add_vline(x=0, line_color="black")
        fig.add_trace(
            go.Scatter(
                x=ellipse[0],
                y=ellipse[1],
                name=f"Hotelling's T^2 [{setdict['ellipse_conf_level']*100:.4g}%]",
                mode="lines",
                line=dict(
                    color="red",
                    width=2,
                ),
            )
        )

    fig.update_layout(
        title_text=setdict["title"],
        margin=margin_dict,
        hovermode="closest",
        showlegend=setdict["show_legend"],
        legend=dict(
            orientation="h",
            traceorder="normal",
            font=dict(family="sans-serif", size=12, color="#000"),
            bordercolor="#DDDDDD",
            borderwidth=1,
        ),
        autosize=False,
        xaxis=dict(
            gridwidth=1,
            mirror=True,
            showspikes=True,
            visible=True,
        ),
        yaxis=dict(
            gridwidth=2,
            type="linear",
            autorange=True,
            showspikes=True,
            visible=True,
            showline=True,
            side="left",
        ),
        width=setdict["html_aspect_ratio_w_over_h"] * setdict["html_image_height"],
        height=setdict["html_image_height"],
    )
    if is_3d:
        # The 3D scene axes are built explicitly, each with its own component title. They are
        # not copied from the 2D layout axes: the 2D y-axis carries `side`, which is not a
        # property of 3D scene axes.
        scene_axis_style = dict(gridwidth=1, mirror=True, showspikes=True, visible=True)
        fig.update_layout(
            scene=dict(
                xaxis=dict(title_text=f"PC {pc_horiz}", **scene_axis_style),
                yaxis=dict(title_text=f"PC {pc_vert}", **scene_axis_style),
                zaxis=dict(title_text=f"PC {pc_depth}", **scene_axis_style),
            ),
        )
    return fig


def loadings_plot(
    model, loadings_type="p", pc_horiz=1, pc_vert=2, settings: Dict = None, fig=None
) -> go.Figure:
    """Generates a 2-dimensional loadings plot for the given latent variable model.

    Parameters
    ----------
    model : MVmodel object (PCA, or PLS)
        A latent variable model generated by this library.

    loadings_type : str, optional
        A choice of the following (case-insensitive), by default 'p':
            'p'   : the P (projection) loadings: the only option possible for PCA
            'w'   : the W loadings: suitable for PLS
            'w*'  : the W* (or R) loadings: suitable for PLS
            'w*c' : the W* (from X-space) with C loadings from the Y-space: suitable for PLS
            'c'   : the C loadings from the Y-space: suitable for PLS

        For a PCA model (one without `direct_weights`) any other choice besides 'p' is ignored.
        For a PLS model an unrecognised choice shows the W* loadings.

    pc_horiz : int, optional
        Which component to plot on the horizontal axis, by default 1 (the first component)
    pc_vert : int, optional
        Which component to plot on the vertical axis, by default 2 (the second component)
    settings : dict
        Default settings are = {
            "title": f"Loadings plot [{loadings_type}] of component {pc_horiz} vs component "
                     f"{pc_vert}"
                Overall plot title

            "show_labels": True,
                Adds a label for each column. Labels are always available in the hover.

            "html_image_height": 500,
                in pixels

            "html_aspect_ratio_w_over_h": 16/9,
                sets the image width, as a ratio of the height

        }
    fig : plotly.graph_objects.Figure, optional
        An existing figure to add the plot to; a new figure is created if not given.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure containing the loadings plot.
    """
    plot_pre_checks(model, pc_horiz, pc_vert, pc_depth=0)
    margin_dict: Dict = dict(l=10, r=10, b=5, t=80)  # Defaults: l=80, r=80, t=100, b=80

    # `direct_weights` (W*) is taken as the marker of a PLS model. For any other model (PCA)
    # the P loadings are the only choice, so the requested type is ignored, as documented.
    is_pls = hasattr(model, "direct_weights")
    if not is_pls:
        loadings_type = "p"
    requested = loadings_type.lower()

    class Settings(BaseModel):
        title: str = (
            f"Loadings plot [{loadings_type.upper()}] of component {pc_horiz} vs "
            f"component {pc_vert}"
        )
        show_labels: bool = True
        html_image_height: float = 500.0
        html_aspect_ratio_w_over_h: float = 16 / 9.0

    setdict = Settings(**(settings or {})).dict()
    if fig is None:
        fig = go.Figure()

    # Loadings type -> (model attribute holding those loadings, name of the plotted trace).
    sources = {
        "p": ("x_loadings", "X-space loadings P"),
        "w": ("x_weights", "X-space loadings W"),
        "w*": ("direct_weights", "X-space loadings W*"),
        "w*c": ("direct_weights", "X-space loadings W*"),
        "c": ("y_loadings", "Y-space loadings C"),
    }
    # Only reachable for PLS models: an unrecognised type falls back to the W* loadings.
    attribute, trace_name = sources.get(requested, sources["w*"])
    what = getattr(model, attribute)
    # 'w*c' overlays the Y-space C loadings on top of the W* loadings.
    extra = model.y_loadings if requested == "w*c" else None
    mode = "markers+text" if setdict["show_labels"] else "markers"

    fig.add_trace(
        go.Scatter(
            x=what.loc[:, pc_horiz],
            y=what.loc[:, pc_vert],
            name=trace_name,
            mode=mode,
            marker=dict(
                color="darkblue",
                symbol="circle",
                size=7,
            ),
            text=what.index,
            textposition="top center",
        )
    )

    # A legend is only useful when two sets of loadings share the plot.
    add_legend = extra is not None
    if add_legend:
        fig.add_trace(
            go.Scatter(
                x=extra.loc[:, pc_horiz],
                y=extra.loc[:, pc_vert],
                name="Y-space loadings C",
                mode=mode,
                marker=dict(
                    color="purple",
                    symbol="star",
                    size=7,
                ),
                text=extra.index,
                textposition="bottom center",
            )
        )

    fig.update_layout(
        xaxis_title_text=f"PC {pc_horiz}", yaxis_title_text=f"PC {pc_vert}"
    )
    fig.add_hline(y=0, line_color="black")
    fig.add_vline(x=0, line_color="black")
    fig.update_layout(
        title_text=setdict["title"],
        margin=margin_dict,
        hovermode="closest",
        showlegend=add_legend,
        autosize=False,
        xaxis=dict(
            gridwidth=1,
            mirror=True,
            showspikes=True,
            visible=True,
        ),
        yaxis=dict(
            gridwidth=2,
            type="linear",
            autorange=True,
            showspikes=True,
            visible=True,
            showline=True,
            side="left",
        ),
        width=setdict["html_aspect_ratio_w_over_h"] * setdict["html_image_height"],
        height=setdict["html_image_height"],
    )
    return fig


def spe_plot(
    model,
    with_a=-1,
    items_to_highlight: Dict[str, list] = None,
    settings: Dict = None,
    fig=None,
) -> go.Figure:
    """Generates a squared-prediction error (SPE) plot for the given latent variable model using
    `with_a` number of latent variables. The default will use the total number of latent variables
    which have already been fitted.

    Parameters
    ----------
    model : MVmodel object (PCA, or PLS)
        A latent variable model generated by this library.
    with_a : int, optional
        Uses this many number of latent variables, and therefore shows the SPE after this number of
        model components. By default the total number of components fitted will be used. Negative
        values index the fitted components from the end; 0 is not allowed.
    items_to_highlight : dict, optional
        keys:   a string which can be json.loads(...) and turns into a Plotly marker specifier.
        values: a list of identifiers for the items to highlight [index names]
        For example:
            items_to_highlight = {'{"color": "red", "symbol": "cross"}': items_in_red}

            will draw the observations listed in `items_in_red` in that colour and shape.

        Identifiers that are not in the model's index are ignored. An observation listed under
        several keys is drawn with the styling of the first of those keys only.

    settings : dict
        Default settings are = {
            "show_limit": True [bool],
                Should the SPE limit be plotted.

            "conf_level": 0.95 [float]
                If the limit line is added, which confidence level is used. 0.0 < number < 1.00.

            "title": f"Squared prediction error plot after fitting {with_a} components,
                       with the {conf_level*100}% confidence limit"
                Overall plot title

            "default_marker": optional, [dict]
                dict(color="darkblue", symbol="circle", size=7)

            "show_labels": False,
                Adds a label for each observation. Labels are always available in the hover.

            "show_legend": False,
                Shows a clickable legend (allows to turn the limit on/off)

            "html_image_height": 500,
                Image height, in pixels.

            "html_aspect_ratio_w_over_h": 16/9,
                Sets the image width, as a ratio of the height.

        }
    fig : plotly.graph_objects.Figure, optional
        An existing figure to add the plot to; a new figure is created if not given.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure containing the SPE plot.

    Raises
    ------
    AssertionError
        If `with_a` is 0, or larger than the number of components fitted.
    """
    # TO CONSIDER: allow a setting `as_line`: which connects the points with line segments
    margin_dict: Dict = dict(l=10, r=10, b=5, t=80)  # Defaults: l=80, r=80, t=100, b=80

    # Raised explicitly (not via `assert`) so these checks also run under `python -O`; the
    # AssertionError type is kept for callers that already catch it.
    if with_a < 0:
        # Get the actual name of the last column in the model if negative indexing is used
        with_a = model.squared_prediction_error.columns[with_a]
    elif with_a == 0:
        raise AssertionError("`with_a` must be >= 1, or specified with negative indexing")
    if not with_a <= model.A:
        raise AssertionError("`with_a` must be <= the number of components fitted")

    def default_title(conf_level: float) -> str:
        """Default plot title, quoting the confidence level used for the limit."""
        return (
            "Squared prediction error plot after "
            f"fitting {with_a} component{'s' if with_a > 1 else ''}"
            f", with the {conf_level*100}% confidence limit"
        )

    class Settings(BaseModel):
        show_limit: bool = True
        conf_level: float = 0.95  # constrained to (0, 1) by the validator below
        title: str = default_title(conf_level)
        default_marker: Dict = dict(color="darkblue", symbol="circle", size=7)
        show_labels: bool = False
        show_legend: bool = False
        html_image_height: float = 500.0
        html_aspect_ratio_w_over_h: float = 16 / 9.0

        @validator("conf_level")
        def check_conf_level(cls, v):
            if v >= 1:
                raise ValueError("0.0 < `conf_level` < 1.0")
            if v <= 0:
                raise ValueError("0.0 < `conf_level` < 1.0")
            return v

    user_settings = dict(settings or {})
    setdict = Settings(**user_settings).dict()
    if "title" not in user_settings:
        # The class-level default was built with 0.95; quote the level actually requested.
        setdict["title"] = default_title(setdict["conf_level"])
    if fig is None:
        fig = go.Figure()

    name = f"SPE values after {with_a} component{'s' if with_a > 1 else ''}"

    # Split the observations into the default group and the highlighted groups. An observation
    # claimed by an earlier key is not offered to later keys; unknown identifiers are dropped.
    # Filtering the index keeps the observation order, which sets the x-axis order.
    default_index = list(model.squared_prediction_error.index)
    highlights: Dict[str, list] = {}
    for key, items in (items_to_highlight or {}).items():
        wanted = set(items)
        highlights[key] = [idx for idx in default_index if idx in wanted]
        default_index = [idx for idx in default_index if idx not in wanted]

    mode = "markers+text" if setdict["show_labels"] else "markers"
    fig.add_trace(
        go.Scatter(
            x=default_index,
            y=model.squared_prediction_error.loc[default_index, with_a],
            name=name,
            mode=mode,
            marker=setdict["default_marker"],
            text=default_index,
            textposition="top center",
            showlegend=setdict["show_legend"],
        )
    )
    # Items to highlight, if any: each key is a JSON-encoded marker specification.
    for key, index in highlights.items():
        fig.add_trace(
            go.Scatter(
                x=index,
                y=model.squared_prediction_error.loc[index, with_a],
                name=name,
                mode=mode,
                marker=json.loads(key),
                text=index,
                textposition="top center",
            )
        )

    if setdict["show_limit"]:
        limit_SPE_conf_level = model.SPE_limit(conf_level=setdict["conf_level"])
        limit_name = f'{setdict["conf_level"]*100:.3g}% limit'
        fig.add_hline(
            y=limit_SPE_conf_level,
            line_color="red",
            annotation_text=limit_name,
            annotation_position="bottom right",
            name=limit_name,
        )
    fig.add_hline(y=0, line_color="black")
    fig.update_layout(
        title_text=setdict["title"],
        margin=margin_dict,
        hovermode="closest",
        showlegend=setdict["show_legend"],
        legend=dict(
            orientation="h",
            traceorder="normal",
            font=dict(family="sans-serif", size=12, color="#000"),
            bordercolor="#DDDDDD",
            borderwidth=1,
        ),
        autosize=False,
        xaxis=dict(
            gridwidth=1,
            mirror=True,
            showspikes=True,
            visible=True,
        ),
        yaxis=dict(
            title_text=name,  # the plotted quantity, not the limit label
            gridwidth=2,
            type="linear",
            autorange=True,
            showspikes=True,
            visible=True,
            showline=True,  # show a separating line
            side="left",  # axis on the left-hand side
        ),
        width=setdict["html_aspect_ratio_w_over_h"] * setdict["html_image_height"],
        height=setdict["html_image_height"],
    )
    return fig


def t2_plot(
    model,
    with_a=-1,
    items_to_highlight: Dict[str, list] = None,
    settings: Dict = None,
    fig=None,
) -> go.Figure:
    """Generates a Hotelling's T2 (T^2) plot for the given latent variable model using
    `with_a` number of latent variables. The default will use the total number of latent variables
    which have already been fitted.

    Parameters
    ----------
    model : MVmodel object (PCA, or PLS)
        A latent variable model generated by this library.
    with_a : int, optional
        Uses this many number of latent variables, and therefore shows the T2 after this number of
        model components. By default the total number of components fitted will be used. Negative
        values index the fitted components from the end; 0 is not allowed.
    items_to_highlight : dict, optional
        keys:   a string which can be json.loads(...) and turns into a Plotly marker specifier.
        values: a list of identifiers for the items to highlight [index names]
        For example:
            items_to_highlight = {'{"color": "red", "symbol": "cross"}': items_in_red}

            will draw the observations listed in `items_in_red` in that colour and shape.

        Identifiers that are not in the model's index are ignored. An observation listed under
        several keys is drawn with the styling of the first of those keys only.

    settings : dict
        Default settings are = {
            "show_limit": True [bool],
                Should the T2 limit be plotted.

            "conf_level": 0.95 [float]
                If the limit line is added, which confidence level is used. 0.0 < number < 1.00.

            "title": f"Hotelling's T2 plot after fitting {with_a} components,
                       with the {conf_level*100}% confidence limit"
                Overall plot title

            "default_marker": optional, [dict]
                dict(color="darkblue", symbol="circle", size=7)

            "show_labels": False,
                Adds a label for each observation. Labels are always available in the hover.

            "show_legend": False,
                Shows a clickable legend (allows to turn the limit on/off)

            "html_image_height": 500,
                Image height, in pixels.

            "html_aspect_ratio_w_over_h": 16/9,
                Sets the image width, as a ratio of the height.
        }
    fig : plotly.graph_objects.Figure, optional
        An existing figure to add the plot to; a new figure is created if not given.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure containing the T2 plot.

    Raises
    ------
    AssertionError
        If `with_a` is 0, or larger than the number of components fitted.
    """
    # TO CONSIDER: allow a setting `as_line`: which connects the points with line segments
    margin_dict: Dict = dict(l=10, r=10, b=5, t=80)  # Defaults: l=80, r=80, t=100, b=80

    # Raised explicitly (not via `assert`) so these checks also run under `python -O`; the
    # AssertionError type matches the other input checks in this module.
    if with_a < 0:
        # Negative values index from the end: resolve them to the actual column label.
        with_a = model.Hotellings_T2.columns[with_a]
    elif with_a == 0:
        raise AssertionError("`with_a` must be >= 1, or specified with negative indexing")
    if not with_a <= model.A:
        raise AssertionError("`with_a` must be <= the number of components fitted")

    def default_title(conf_level: float) -> str:
        """Default plot title, quoting the confidence level used for the limit."""
        return (
            f"Hotelling's T2 plot after fitting {with_a} component{'s' if with_a > 1 else ''}"
            f", with the {conf_level*100}% confidence limit"
        )

    class Settings(BaseModel):
        show_limit: bool = True
        conf_level: float = 0.95  # constrained to (0, 1) by the validator below
        title: str = default_title(conf_level)
        default_marker: Dict = dict(color="darkblue", symbol="circle", size=7)
        show_labels: bool = False
        show_legend: bool = False
        html_image_height: float = 500.0
        html_aspect_ratio_w_over_h: float = 16 / 9.0

        @validator("conf_level")
        def check_conf_level(cls, v):
            if v >= 1:
                raise ValueError("0.0 < `conf_level` < 1.0")
            if v <= 0:
                raise ValueError("0.0 < `conf_level` < 1.0")
            return v

    user_settings = dict(settings or {})
    setdict = Settings(**user_settings).dict()
    if "title" not in user_settings:
        # The class-level default was built with 0.95; quote the level actually requested.
        setdict["title"] = default_title(setdict["conf_level"])
    if fig is None:
        fig = go.Figure()

    name = f"T2 values after {with_a} component{'s' if with_a > 1 else ''}"

    # Split the observations into the default group and the highlighted groups. An observation
    # claimed by an earlier key is not offered to later keys; unknown identifiers are dropped.
    # Filtering the index keeps the observation order, which sets the x-axis order.
    default_index = list(model.Hotellings_T2.index)
    highlights: Dict[str, list] = {}
    for key, items in (items_to_highlight or {}).items():
        wanted = set(items)
        highlights[key] = [idx for idx in default_index if idx in wanted]
        default_index = [idx for idx in default_index if idx not in wanted]

    mode = "markers+text" if setdict["show_labels"] else "markers"
    fig.add_trace(
        go.Scatter(
            x=default_index,
            y=model.Hotellings_T2.loc[default_index, with_a],
            name=name,
            mode=mode,
            marker=setdict["default_marker"],
            text=default_index,
            textposition="top center",
            showlegend=setdict["show_legend"],
        )
    )
    # Items to highlight, if any: each key is a JSON-encoded marker specification.
    for key, index in highlights.items():
        fig.add_trace(
            go.Scatter(
                x=index,
                y=model.Hotellings_T2.loc[index, with_a],
                name=name,
                mode=mode,
                marker=json.loads(key),
                text=index,
                textposition="top center",
            )
        )

    if setdict["show_limit"]:
        limit_HT2_conf_level = model.T2_limit(conf_level=setdict["conf_level"])
        limit_name = f'{setdict["conf_level"]*100:.3g}% limit'
        fig.add_hline(
            y=limit_HT2_conf_level,
            line_color="red",
            annotation_text=limit_name,
            annotation_position="bottom right",
            name=limit_name,
        )
    fig.add_hline(y=0, line_color="black")
    fig.update_layout(
        title_text=setdict["title"],
        margin=margin_dict,
        hovermode="closest",
        showlegend=setdict["show_legend"],
        legend=dict(
            orientation="h",
            traceorder="normal",
            font=dict(family="sans-serif", size=12, color="#000"),
            bordercolor="#DDDDDD",
            borderwidth=1,
        ),
        autosize=False,
        xaxis=dict(
            gridwidth=1,
            mirror=True,
            showspikes=True,
            visible=True,
        ),
        yaxis=dict(
            title_text=name,  # the plotted quantity, not the limit label
            gridwidth=2,
            type="linear",
            autorange=True,
            showspikes=True,
            visible=True,
            showline=True,
            side="left",
        ),
        width=setdict["html_aspect_ratio_w_over_h"] * setdict["html_image_height"],
        height=setdict["html_image_height"],
    )
    return fig
