"""Utilities for creating Sankey plots of pystv results."""
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Default color sequences used by create_sankey_fig. Node and link colors are
# picked by candidate index modulo the palette length, so the same candidate
# keeps the same color in every round.
NODE_PALETTE = px.colors.qualitative.Dark2
LINK_PALETTE = px.colors.qualitative.Set2


def result_to_sankey_data(race_result):
    """Convert a pystv race result into node and link tables for a Sankey plot.

    Every candidate, plus a leading "-- exhausted --" pseudo-candidate at
    index 0, gets one node per round. The node for candidate ``c`` in round
    ``r`` sits at row ``r * num_cands + c`` of the node table. Between
    consecutive rounds, a candidate's votes flow to the targets listed in
    that round's ``transfers``. Whatever remains flows to the candidate's own
    node in the next round.

    Args:
        race_result: Object exposing ``metadata.names`` (candidate names, any
            iterable) and ``rounds`` (a sequence of per-round results, each
            with ``count`` indexable by candidate index and ``transfers``
            mapping a source index to ``{target_index: votes}``).

    Returns:
        ``(df_nodes, df_links)``:
        ``df_nodes`` has columns ``id`` (candidate index, repeated per round)
        and ``label``.
        ``df_links`` has columns ``source`` and ``target`` (row positions in
        ``df_nodes``), ``value`` (votes moved) and ``id`` (source candidate
        index, used for coloring).
        Both frames are empty when the result has no rounds.
    """
    rounds = race_result.rounds
    num_rounds = len(rounds)

    # Index 0 is reserved for exhausted ballots; candidates follow in order.
    # Unpacking accepts any iterable of names (list, tuple, ...), whereas
    # ``list + names`` only works when ``names`` is itself a list.
    labels = ["-- exhausted --", *race_result.metadata.names]
    num_cands = len(labels)

    if num_rounds == 0:
        # pd.concat([]) would raise "No objects to concatenate"; return empty
        # tables with the usual columns instead.
        return (
            pd.DataFrame(columns=["id", "label"]),
            pd.DataFrame(columns=["source", "target", "value", "id"]),
        )

    # One copy of the candidate nodes per round. The positional row index of
    # the concatenated frame is what the link source/target columns refer to.
    df_round = pd.DataFrame({"id": range(num_cands), "label": labels})
    df_nodes = pd.concat([df_round] * num_rounds, ignore_index=True)

    data_links = []
    # The last round has no successor, so it produces no outgoing links.
    for rnd in range(num_rounds - 1):
        round_result = rounds[rnd]
        transfers = round_result.transfers
        for src in range(num_cands):
            # Votes not transferred away stay with ``src`` into the next round.
            src_left = round_result.count[src]
            if src in transfers:
                for tgt, diff in transfers[src].items():
                    data_links.append((rnd, src, tgt, diff))
                    src_left -= diff
            if src_left > 0:
                data_links.append((rnd, src, src, src_left))

    df_links = pd.DataFrame(data_links, columns=["round", "source", "target", "value"])
    # Color links by their source candidate, captured before the shift below.
    df_links["id"] = df_links["source"]
    # Turn candidate indices into node rows: sources live in round ``rnd``,
    # targets in round ``rnd + 1``.
    df_links["source"] += df_links["round"] * num_cands
    df_links["target"] += (df_links["round"] + 1) * num_cands

    return df_nodes, df_links.drop(columns="round")


def create_sankey_fig(
    df_nodes, df_links, node_palette=NODE_PALETTE, link_palette=LINK_PALETTE
):
    """Build a Plotly figure containing a Sankey diagram of vote flows.

    Args:
        df_nodes: Node table with ``label`` and ``id`` columns, one row per
            diagram node (as produced by ``result_to_sankey_data``).
        df_links: Link table with ``source``, ``target``, ``value`` and ``id``
            columns; ``source``/``target`` are row positions in ``df_nodes``.
        node_palette: Sequence of colors; node ``id`` values select from it
            cyclically.
        link_palette: Sequence of colors; link ``id`` values select from it
            cyclically.

    Returns:
        A ``plotly.graph_objects.Figure`` holding the Sankey trace, with an
        explanatory note placed below the plot and zero left/top margins.
    """

    def cycle_colors(palette, keys):
        # Wrap around the palette so any number of distinct ids gets a color.
        size = len(palette)
        return [palette[key % size] for key in keys]

    node_spec = dict(
        thickness=10,
        line=dict(color="black", width=1),
        label=df_nodes.label,
        color=cycle_colors(node_palette, df_nodes.id),
    )
    link_spec = dict(
        source=df_links.source,
        target=df_links.target,
        value=df_links.value,
        color=cycle_colors(link_palette, df_links.id),
    )
    trace = go.Sankey(node=node_spec, link=link_spec, visible=True)
    figure = go.Figure(data=[trace])

    # Caption lines are joined with HTML line breaks in a single annotation
    # anchored just below the plotting area (paper coordinates).
    caption_lines = ["Height of bars and traces is proportional to vote count."]
    figure.add_annotation(
        text="<br>".join(caption_lines),
        x=0,
        y=-0.2,
        xref="paper",
        yref="paper",
        align="left",
        showarrow=False,
        font_size=15,
    )
    figure.update_layout(margin=dict(l=0, t=0))
    return figure
