"""Tests for `pystv` package."""
import pytest
from click.testing import CliRunner
from numpy.testing import assert_equal

import pystv
from pystv import cli

from .utils import assert_rr_equal

# Short alias for ``pystv.RoundResult``, the per-round record that the tests in
# this module compare against the entries returned by ``pystv.run_stv``.
# NOTE(review): from the expected values used below, the positional fields
# appear to be (per-index tallies where index 0 seems to collect votes with no
# remaining candidate, candidates elected this round, candidates eliminated
# this round, {source candidate: {destination index: votes moved}}) — confirm
# against the RoundResult definition.
RR = pystv.RoundResult


# Ballot data for the FairVote multi-winner example used by test_fairvote_example.
# Each entry is ``(ranking, count)``: ``ranking`` lists candidate ids in
# preference order with 0 filling unused ranks, and ``count`` is the number of
# votes cast with that ranking. The trailing "R3"/"R4" comments presumably mark
# the rounds in which those ballots are transferred.
# TODO: Do not require all ballots to be the same length.
FAIRVOTE_BALLOTS = [
    ([1, 2, 3], 625),  # R4
    ([1, 2, 4], 125),  # R4
    ([1, 2, 5], 250),  # R4; Last is any unviable candidate
    ([1, 2, 6], 250),  # R4; Last is any unviable candidate
    ([1, 5, 3], 500),  # R3
    ([1, 5, 4], 500),  # R3
    ([1, 3, 0], 250),
    ([2, 3, 0], 875),  # R4
    ([2, 4, 0], 175),  # R4
    ([2, 5, 0], 350),  # R4; Last is any unviable candidate
    ([2, 6, 0], 350),  # R4; Last is any unviable candidate
    ([3, 0, 0], 1300),
    ([4, 0, 0], 1300),
    ([5, 2, 3], 625),  # R3, R4
    ([5, 2, 4], 125),  # R3, R4
    ([5, 2, 6], 500),  # R3, R4; Last is any unviable candidate
    ([5, 3, 0], 100),  # R3
    ([6, 3, 0], 580),
    ([6, 4, 0], 300),
    ([6, 2, 3], 50),  # R4
    ([6, 2, 4], 10),  # R4
    ([6, 2, 5], 40),  # R4; Last is any unviable candidate
    ([6, 5, 3], 10),  # R3
    ([6, 5, 4], 10),  # R3
]


def test_2cands_1seat():
    """Two candidates, one seat: candidate 2 is ranked first on 2 of 3 votes.

    The count finishes in a single round with candidate 2 elected.
    """
    rankings = [[2, 1], [1, 2]]
    counts = [2, 1]

    rounds = pystv.run_stv(rankings, counts, num_seats=1)

    expected_rounds = [RR([0, 1, 2], [2], [], {})]
    assert len(rounds) == len(expected_rounds)
    for actual, expected in zip(rounds, expected_rounds):
        assert_rr_equal(actual, expected)


def test_2cands_1seat_undervote():
    """A ballot ranking only one candidate (``[2, 0]``) still counts for it.

    Candidate 2 is first on two of three single-vote ballots and wins in one round.
    """
    rankings = [[2, 0], [2, 1], [1, 2]]
    counts = [1] * len(rankings)

    rounds = pystv.run_stv(rankings, counts, num_seats=1)

    expected_rounds = [RR([0, 1, 2], [2], [], {})]
    assert len(rounds) == len(expected_rounds)
    for actual, expected in zip(rounds, expected_rounds):
        assert_rr_equal(actual, expected)


def test_3cands_2seat_undervote():
    """Two seats with many partially-ranked (undervoted) ballots.

    Despite the test name, the ballots rank four candidates (ids 1-4). The
    count takes four rounds: eliminate 1, elect 4, eliminate 2, elect 3.
    """
    rankings = [
        [4, 0],
        [3, 0],
        [2, 4],
        [2, 0],
        [1, 4],
        [1, 0],
    ]
    counts = [3, 3, 1, 1, 1, 1]

    rounds = pystv.run_stv(rankings, counts, num_seats=2)

    expected_rounds = [
        RR([0, 2, 2, 3, 3], [], [1], {1: {0: 1, 4: 1}}),
        RR([1, 0, 2, 3, 4], [4], [], {}),
        RR([1, 0, 2, 3, 4], [], [2], {2: {0: 2}}),
        RR([3, 0, 0, 3, 4], [3], [], {}),
    ]
    assert len(rounds) == len(expected_rounds)
    for actual, expected in zip(rounds, expected_rounds):
        assert_rr_equal(actual, expected)


def test_3cands_2seats_1round():
    """Both seats are filled in the first round.

    Candidate 2's surplus of one vote flows to candidate 3.
    """
    rankings = [[2, 1, 3], [1, 2, 3]]
    counts = [3, 2]

    rounds = pystv.run_stv(rankings, counts, num_seats=2)

    expected_rounds = [RR([0, 2, 3, 0], [1, 2], [], {2: {3: 1}})]
    assert len(rounds) == len(expected_rounds)
    for actual, expected in zip(rounds, expected_rounds):
        assert_rr_equal(actual, expected)


def test_3cands_1seat_multiround():
    """Single seat decided after one elimination.

    Candidates 1 and 2 tie at 2 first-round votes. Eliminating candidate 3
    moves its vote to candidate 1, who is then elected.
    """
    rankings = [
        [1, 2, 3],
        [2, 1, 3],
        [3, 1, 2],
    ]
    counts = [2, 2, 1]

    rounds = pystv.run_stv(rankings, counts, num_seats=1)

    expected_rounds = [
        RR([0, 2, 2, 1], [], [3], {3: {1: 1}}),
        RR([0, 3, 2, 0], [1], [], {}),
    ]
    assert len(rounds) == len(expected_rounds)
    for actual, expected in zip(rounds, expected_rounds):
        assert_rr_equal(actual, expected)


def test_3cands_2seats_multiround():
    """Two seats among three candidates.

    Despite the name, the expected result is a single round that elects
    candidates 2 and 3, with one surplus vote moving from candidate 2 to 1.
    """
    rankings = [
        [1, 3, 2],
        [2, 1, 3],
        [3, 1, 2],
        [3, 2, 1],
    ]
    counts = [2, 4, 1, 2]

    rounds = pystv.run_stv(rankings, counts, num_seats=2)

    expected_rounds = [RR([0, 2, 4, 3], [2, 3], [], {2: {1: 1}})]
    assert len(rounds) == len(expected_rounds)
    for actual, expected in zip(rounds, expected_rounds):
        assert_rr_equal(actual, expected)


def test_3cands_2seats_multiround_with_adjust():
    """Two seats filled over three rounds.

    Candidate 2 is elected first and one surplus vote moves to candidate 1.
    Candidate 2's tally then reads 4 instead of 5. Candidate 1 is eliminated
    next, handing 3 votes to candidate 3, who is elected in the final round.
    """
    rankings = [
        [1, 3, 2],
        [2, 1, 3],
        [3, 2, 1],
        [3, 2, 1],
    ]
    counts = [2, 5, 1, 2]

    rounds = pystv.run_stv(rankings, counts, num_seats=2)

    expected_rounds = [
        RR([0, 2, 5, 3], [2], [], {2: {1: 1}}),
        RR([0, 3, 4, 3], [], [1], {1: {3: 3}}),
        RR([0, 0, 4, 6], [3], [], {}),
    ]
    assert len(rounds) == len(expected_rounds)
    for actual, expected in zip(rounds, expected_rounds):
        assert_rr_equal(actual, expected)


def test_fairvote_example():
    """Example from FairVote's website.

    https://fairvote.org/archives/multi_winner_rcv_example/

    Three seats are filled over five rounds: elect 1, eliminate 6,
    eliminate 5, elect 2, elect 3.
    """
    rankings, counts = zip(*FAIRVOTE_BALLOTS)

    rounds = pystv.run_stv(rankings, counts, num_seats=3)

    expected_rounds = [
        RR(
            [0, 2500, 1750, 1300, 1300, 1350, 1000],
            [1],
            [],
            {1: {2: 100, 3: 20, 5: 80}},
        ),
        RR(
            [0, 2300, 1850, 1320, 1300, 1430, 1000],
            [],
            [6],
            {6: {2: 100, 3: 580, 4: 300, 5: 20}},
        ),
        RR(
            [0, 2300, 1950, 1900, 1600, 1450, 0],
            [],
            [5],
            {5: {2: 1250, 3: 150, 4: 50}},
        ),
        RR(
            [0, 2300, 3200, 2050, 1650, 0, 0],
            [2],
            [],
            {2: {0: 360, 3: 450, 4: 90}},
        ),
        RR(
            [360, 2300, 2300, 2500, 1740, 0, 0],
            [3],
            [],
            {3: {0: 200}},
        ),
    ]
    assert len(rounds) == len(expected_rounds)
    for actual, expected in zip(rounds, expected_rounds):
        assert_rr_equal(actual, expected)


def test_validate_and_standardize_ballots_ok():
    """Well-formed, equal-length ballots come back with one extra zero column."""
    raw = [[1, 0, 0], [1, 2, 0], [1, 2, 3]]
    expected = [[1, 0, 0, 0], [1, 2, 0, 0], [1, 2, 3, 0]]

    standardized = pystv.validate_and_standardize_ballots(raw)

    assert_equal(standardized, expected)

def test_validate_and_standardize_ballots_ragged():
    """Ballots of differing lengths are zero-padded to a common width."""
    raw = [[1, 0, 0], [1, 2], [1, 2, 3]]
    expected = [[1, 0, 0, 0], [1, 2, 0, 0], [1, 2, 3, 0]]

    standardized = pystv.validate_and_standardize_ballots(raw)

    assert_equal(standardized, expected)


def test_validate_and_standardize_ballots_negative():
    """A negative entry is rejected and the offending ballot index is reported."""
    raw = [[1, 0, 0], [1, 2, -1], [1, 2, 3]]
    expected_message = r"Negative rankings on ballots: \[1\]"

    with pytest.raises(ValueError, match=expected_message):
        pystv.validate_and_standardize_ballots(raw)


def test_validate_and_standardize_ballots_invalid_ranking():
    """A zero followed by a nonzero rank is a skipped ranking.

    Every offending ballot index is listed in the error message.
    """
    raw = [[1, 0, 0], [0, 1, 0], [1, 2, 0], [0, 0, 1]]
    expected_message = r"Skipped rankings on ballots: \[1, 3\]"

    with pytest.raises(ValueError, match=expected_message):
        pystv.validate_and_standardize_ballots(raw)


def test_command_line_interface():
    """Test the CLI.

    Invokes the entry point bare and with ``--help``. Both runs must exit
    cleanly and print the expected text.
    """
    runner = CliRunner()

    result = runner.invoke(cli.main)
    # CliRunner catches exceptions by default, so a crash only surfaces as a
    # nonzero exit code. Include the output and exception in the failure
    # message so the cause is visible.
    assert result.exit_code == 0, (result.output, result.exception)
    assert "pystv.cli.main" in result.output

    help_result = runner.invoke(cli.main, ["--help"])
    assert help_result.exit_code == 0, (help_result.output, help_result.exception)
    assert "--help  Show this message and exit." in help_result.output
