"""Utilities for creating Sankey plots of pystv results."""
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Default node colors for ``result_to_sankey_data``; cycled by each node's
# per-round index (index 0 is the "<exhausted>" pseudo-candidate).
NODE_PALETTE = px.colors.qualitative.Dark2
# Default link colors for ``result_to_sankey_data``; cycled by each link's
# per-round source index.
LINK_PALETTE = px.colors.qualitative.Set2


def result_to_sankey_data(
    race_result,
    show_exhausted=False,
    node_palette=NODE_PALETTE,
    link_palette=LINK_PALETTE,
):
    """Flatten a race result into node and link tables for a Sankey plot.

    Every round contributes one node per candidate, preceded by an
    ``"<exhausted>"`` pseudo-candidate at per-round index 0. Links join round
    ``r`` to round ``r + 1``: one link per transfer out of a candidate, plus a
    link back to the same candidate carrying whatever it kept (if positive).

    Args:
        race_result: Result object exposing ``rounds`` (a sequence whose items
            have ``count``, indexed by per-round node index, and
            ``transfers``, a mapping ``source -> {target: amount}``) and
            ``metadata.names``. NOTE(review): shape inferred from usage here;
            presumably ``names[i]`` is per-round node ``i + 1``.
        show_exhausted: If False, the exhausted nodes and every link flowing
            into an exhausted node stay in the tables but are made fully
            transparent, and the exhausted nodes get an empty label.
        node_palette: Colors cycled over the per-round node index.
        link_palette: Colors cycled over each link's per-round source index.

    Returns:
        ``(df_nodes, df_links)``. ``df_nodes`` has columns
        ``round, id, label, color``; its row position ``round * n + id``
        (``n`` nodes per round) is the global node index. ``df_links`` has
        columns ``round, source, target, value, color``, with ``source`` and
        ``target`` already converted to global node indices.
    """
    rounds = race_result.rounds
    n_rounds = len(rounds)

    node_labels = ["<exhausted>"] + race_result.metadata.names
    n_nodes = len(node_labels)  # nodes per round, including "<exhausted>"

    def cycle_color(palette, idx):
        return palette[idx % len(palette)]

    # One copy of the node table per round; the concat keys become the
    # "round" level and the per-round row number becomes "id".
    template = pd.DataFrame(
        {
            "label": node_labels,
            "color": [cycle_color(node_palette, i) for i in range(n_nodes)],
        }
    )
    df_nodes = pd.concat(
        [template for _ in range(n_rounds)],
        keys=range(n_rounds),
        names=["round", "id"],
    ).reset_index()

    link_rows = []
    for rnd in range(n_rounds - 1):
        current = rounds[rnd]
        transfers = current.transfers
        for src in range(n_nodes):
            link_color = cycle_color(link_palette, src)
            kept = current.count[src]
            outflows = transfers[src].items() if src in transfers else ()
            for tgt, amount in outflows:
                link_rows.append((rnd, src, tgt, amount, link_color))
                kept -= amount
            # Votes the candidate retains flow to its own node next round.
            if kept > 0:
                link_rows.append((rnd, src, src, kept, link_color))

    df_links = pd.DataFrame(
        link_rows, columns=["round", "source", "target", "value", "color"]
    )
    # Convert per-round indices into global node positions: sources sit in
    # round `rnd`, targets in round `rnd + 1`.
    base = df_links["round"] * n_nodes
    df_links["source"] += base
    df_links["target"] += base + n_nodes

    if not show_exhausted:
        transparent = "rgba(0,0,0,0)"
        is_exhausted_node = df_nodes["id"] == 0
        df_nodes.loc[is_exhausted_node, "color"] = transparent
        df_nodes.loc[is_exhausted_node, "label"] = ""
        into_exhausted = (df_links["target"] % n_nodes) == 0
        df_links.loc[into_exhausted, "color"] = transparent

    return df_nodes, df_links


def create_sankey_fig(df_nodes, df_links):
    """Build a Plotly Sankey figure from node and link tables.

    Args:
        df_nodes: Node table with ``label``, ``color`` and 0-based ``round``
            columns, one row per node. Each row's position is the node index
            that ``df_links["source"]`` / ``df_links["target"]`` refer to
            (the layout produced by ``result_to_sankey_data``).
        df_links: Link table with ``source``, ``target``, ``value``,
            ``color`` and 0-based ``round`` columns, where ``round`` is the
            round the votes leave from.

    Returns:
        A ``go.Figure`` holding a single Sankey trace with zero left/top
        margins.
    """
    sankey = go.Sankey(
        node=dict(
            thickness=10,
            line={"width": 0},
            label=df_nodes["label"],
            color=df_nodes["color"],
            # Rounds are stored 0-indexed; add 1 so node hover text uses the
            # same 1-indexed round numbers as the link hover text below.
            customdata=df_nodes["round"] + 1,
            hovertemplate=(
                "%{label} has %{value:.0d} votes "
                "in round %{customdata}<extra></extra>"
            ),
        ),
        link=dict(
            source=df_links["source"],
            target=df_links["target"],
            value=df_links["value"],
            color=df_links["color"],
            # For display, add 2 to round:
            # - add 1 because it is the next round that the transfers are counted
            # - add 1 to start indexing from 1 (real people do not like 0-indexing)
            customdata=df_links["round"] + 2,
            hovertemplate=(
                "%{value:.0d} votes transferred<br />"
                "from %{source.label} to %{target.label}<br />"
                "in round %{customdata}<extra></extra>"
            ),
        ),
        visible=True,
    )

    fig = go.Figure(data=[sankey])
    fig.update_layout(
        margin=dict(l=0, t=0),
    )
    return fig
