"""Utilities for creating Sankey plots of pystv results."""
from dataclasses import dataclass, field

import plotly.express as px
import plotly.graph_objects as go


class PyStvVizError(Exception):
    """Error raised by the Sankey plotting utilities in this module."""


# Qualitative plotly color sequences: one for the candidate nodes and one for
# the links drawn between nodes of consecutive rounds.
NODE_PALETTE = px.colors.qualitative.Dark2
LINK_PALETTE = px.colors.qualitative.Set2
# Number of colors in NODE_PALETTE; candidate indices are wrapped modulo this
# value when picking colors, so any number of candidates can be colored.
NUM_COLORS = len(NODE_PALETTE)


@dataclass
class SankeyData:
    """Container for data useful for creating Sankey Plots.

    The first four lists describe links and are parallel (entry ``i`` of each
    belongs to link ``i``); the last two describe nodes and are parallel to
    each other.  Every field defaults to a fresh empty list.
    """

    # Link data: node index each link starts from.
    source: list = field(default_factory=list)
    # Link data: node index each link ends at.
    target: list = field(default_factory=list)
    # Link data: magnitude carried by each link.
    value: list = field(default_factory=list)
    # Link data: color used to draw each link.
    link_color: list = field(default_factory=list)

    # Node data: display label of each node; `source`/`target` index into this.
    labels: list = field(default_factory=list)
    # Node data: color used to draw each node.
    node_color: list = field(default_factory=list)


def results_to_sankey(results, labels=None):
    """Convert per-round results into node and link lists for a Sankey plot.

    Every candidate gets one node per round; the node for candidate ``c`` in
    round ``r`` has index ``r * num_cands + c``.  For each pair of consecutive
    rounds, the transfers recorded on the earlier round become links from the
    source candidate's node to the target candidate's node in the next round,
    and whatever is left of the source's count (if positive) becomes a link to
    the same candidate in the next round.  Transfers recorded on the final
    round are not used because there is no following round to link to.

    Args:
        results: Non-empty sequence of round results.  Each item must expose
            ``count`` (indexable per candidate) and ``transfers`` (a mapping
            ``{source: {target: amount}}``).  The number of candidates is
            taken from the first round's ``count``.
        labels: Optional sequence with one label per candidate.  When omitted,
            labels ``cand_0``, ``cand_1``, ... are generated.

    Returns:
        SankeyData: populated node labels/colors and link
        sources/targets/values/colors.

    Raises:
        PyStvVizError: If ``results`` is empty, or if ``labels`` is given and
            its length differs from the number of candidates.
    """
    num_rounds = len(results)
    if num_rounds == 0:
        # Fail with the module's own error instead of an IndexError below.
        raise PyStvVizError("Cannot create a Sankey plot from empty results")
    num_cands = len(results[0].count)

    if labels is None:
        labels = [f"cand_{i}" for i in range(num_cands)]
    elif len(labels) != num_cands:
        raise PyStvVizError(
            f"Number of candidates {num_cands} does not match labels: {labels}"
        )

    # Each palette wraps on its own length so the two may differ in size.
    num_link_colors = len(LINK_PALETTE)
    node_color = [NODE_PALETTE[idx % NUM_COLORS] for idx in range(num_cands)]

    data = SankeyData()
    data.node_color = node_color * num_rounds
    # list() guarantees sequence repetition even for array-like labels, where
    # `*` would otherwise mean element-wise multiplication.
    data.labels = list(labels) * num_rounds

    def add_link(source, target, amount, color):
        """Append one link to the parallel link lists of ``data``."""
        data.source.append(source)
        data.target.append(target)
        data.value.append(amount)
        data.link_color.append(color)

    for rnd in range(num_rounds - 1):
        result = results[rnd]
        offset = rnd * num_cands  # index of candidate 0's node in this round
        next_offset = offset + num_cands  # same, for the following round
        for src in range(num_cands):
            color = LINK_PALETTE[src % num_link_colors]
            remaining = result.count[src]
            if src in result.transfers:
                for tgt, diff in result.transfers[src].items():
                    add_link(src + offset, tgt + next_offset, diff, color)
                    remaining -= diff
            # Whatever was not transferred away stays with the candidate.
            if remaining > 0:
                add_link(src + offset, src + next_offset, remaining, color)

    return data


def create_sankey(sankey_data):
    """Build a plotly figure containing a Sankey diagram of ``sankey_data``.

    Args:
        sankey_data: Object providing ``labels`` and ``node_color`` for the
            nodes and ``source``, ``target``, ``value`` and ``link_color`` for
            the links.

    Returns:
        A ``go.Figure`` with a single Sankey trace, an explanatory note
        annotated below the plot, and the left and top margins set to zero.
    """
    node_spec = {
        "thickness": 10,
        "line": {"color": "black", "width": 1},
        "label": sankey_data.labels,
        "color": sankey_data.node_color,
    }
    link_spec = {
        "source": sankey_data.source,
        "target": sankey_data.target,
        "value": sankey_data.value,
        "color": sankey_data.link_color,
    }
    trace = go.Sankey(node=node_spec, link=link_spec, visible=True)

    figure = go.Figure(data=[trace])

    # Notes are joined with HTML line breaks so more can be added later.
    footnotes = ["Height of bars and traces is proportional to vote count."]
    figure.add_annotation(
        text="<br>".join(footnotes),
        # Paper coordinates: left edge, slightly below the plotting area.
        x=0,
        y=-0.2,
        xref="paper",
        yref="paper",
        align="left",
        showarrow=False,
        font_size=15,
    )
    figure.update_layout(margin={"l": 0, "t": 0})
    return figure
