"""Minimalist implementation of Single Transferable Vote."""
import collections

import numpy as np

# Snapshot of one counting round as recorded by run_stv:
#   count      - per-candidate tallies the round's decision was based on
#   elected    - ids of every candidate elected so far (this round included)
#   eliminated - ids of every candidate eliminated so far (this round included)
# The typename matches the module-level name so repr() and pickling agree.
RoundResult = collections.namedtuple("RoundResult", ["count", "elected", "eliminated"])


def run_stv(ballots, votes, num_seats):
    """Count a Single Transferable Vote election round by round.

    Each round either elects every active candidate whose tally reaches the
    quota ``ceil(total votes / (num_seats + 1))`` or, if nobody does,
    eliminates the active candidate with the fewest votes (ties are broken
    with ``np.random.choice``).  Elected candidates pass on their surplus and
    eliminated candidates pass on everything to the next mark on each ballot
    whose candidate is still active.  If nobody reaches the quota and there
    are no more active candidates than open seats, the remaining active
    candidates are elected without a transfer.

    Args:
        ballots: 2-D array-like of candidate ids.  Each row is one ballot
            pattern with marks ordered from first to last preference; ``0``
            denotes an empty mark.
        votes: 1-D array-like with the number of votes cast for each row of
            ``ballots``.
        num_seats: Number of seats to fill.

    Returns:
        A list with one ``RoundResult`` per round.  ``count`` is the masked
        array of tallies the round's decision was based on (only candidates
        active at the start of the round are unmasked); ``elected`` and
        ``eliminated`` hold the ids of all candidates elected / eliminated up
        to and including that round.  Fewer than ``num_seats`` candidates are
        elected when the ballots name fewer candidates than seats.

    Raises:
        ValueError: If ``ballots`` is rejected by
            ``validate_and_standardize_ballots``.
    """
    ballots = validate_and_standardize_ballots(ballots)
    votes = np.array(votes)[:, None]

    num_slots = np.max(ballots) + 1  # zero is reserved for empty marks
    row_idx = np.arange(ballots.shape[0])

    # weights[i, j] is the share of ballot i held by its j-th mark.
    weights = np.zeros_like(ballots, float)
    weights[:, 0] = 1

    votes_needed = int(np.ceil(np.sum(votes) / (num_seats + 1)))

    # Per-slot state: 0 active, 1 elected, -1 eliminated, -2 the empty slot.
    status = np.zeros(num_slots, int)
    status[0] = -2
    # valid[i, j] turns False once the candidate marked there leaves the count.
    valid = np.ones_like(ballots, bool)

    round_info = []
    while np.count_nonzero(status > 0) < num_seats:
        active = np.nonzero(status == 0)[0]
        if active.size == 0:
            # Everybody is already elected or eliminated, so the remaining
            # seats cannot be filled; stop instead of sampling from nothing.
            break

        counts = np.bincount(ballots.ravel(), (weights * votes).ravel(), num_slots)
        counts = np.ma.array(counts, mask=(status != 0))
        elected = counts >= votes_needed
        seats_left = num_seats - np.count_nonzero(status > 0)

        if elected.any():
            to_remove = np.nonzero(elected)[0]
            # Fraction of each ballot the winner keeps; the rest is surplus.
            multipliers = [votes_needed / counts[c] for c in to_remove]
            status[to_remove] = 1
        elif active.size <= seats_left:
            # Nobody can be eliminated without leaving seats empty: the
            # remaining candidates win by default and keep all their votes.
            to_remove = active
            multipliers = [1] * len(to_remove)
            status[to_remove] = 1
        else:  # eliminate candidate with fewest votes (random selection if >1)
            lowest = np.nonzero(np.ma.filled(counts == counts.min(), False))[0]
            to_remove = [np.random.choice(lowest)]
            multipliers = [0]
            status[to_remove] = -1

        # Distribute votes from elected/eliminated candidates to the next active
        # lower ranked candidate.  Elected candidates transfer excess votes, while
        # eliminated candidates transfer all votes.  All departing candidates are
        # invalidated before the next mark is located, so a surplus never lands
        # on another candidate leaving in the same round.
        current = first_nonzero(valid, axis=1)  # Mark holding each ballot now
        valid[np.isin(ballots, to_remove)] = False
        following = first_nonzero(valid, axis=1)  # Next usable mark, -1 if none

        # Ballots whose holding mark changed
        moved = current != following
        # Ballots that moved and still have a later mark to move to
        has_next = moved & (following != -1)
        # Candidate each ballot was sitting on (only meaningful where moved)
        holder = ballots[row_idx, current]

        for cand, mult in zip(to_remove, multipliers):
            from_cand = moved & (holder == cand)
            to_next = from_cand & has_next

            # Votes are transferred to next candidate.
            weights[to_next, following[to_next]] = (
                weights[to_next, current[to_next]] * (1 - mult)
            )
            # Votes are transferred from original candidate.
            weights[from_cand, current[from_cand]] *= mult

        round_info.append(
            RoundResult(counts, np.nonzero(status == 1)[0], np.nonzero(status == -1)[0])
        )
    return round_info


def first_nonzero(arr, axis, invalid_val=-1):
    """Return the index of the first nonzero entry of ``arr`` along ``axis``.

    Args:
        arr: NumPy array to search.
        axis: Axis along which to look for the first nonzero entry.
        invalid_val: Value reported where no entry along ``axis`` is nonzero.

    Returns:
        An array with ``axis`` removed, holding the first nonzero position or
        ``invalid_val`` where there is none.
    """
    is_set = arr != 0
    has_hit = is_set.any(axis=axis)
    # argmax of a boolean array yields the position of the first True (0 if
    # there is none, which is why has_hit is needed to tell the cases apart).
    first_hit = is_set.argmax(axis=axis)
    return np.where(has_hit, first_hit, invalid_val)


def validate_and_standardize_ballots(ballots):
    """Convert ``ballots`` to an array and reject malformed rankings.

    Args:
        ballots: Array-like of rankings; ``0`` denotes an empty mark.

    Returns:
        ``ballots`` as a NumPy array.

    Raises:
        ValueError: If the data is not 2-D, contains a negative entry, or has
            a non-empty mark following an empty one on the same row.  The
            message lists the indices of the offending rows for the last two
            cases.
    """
    ballots = np.asarray(ballots)

    if ballots.ndim != 2:
        raise ValueError("Ballot data has wrong dim: {}".format(ballots.ndim))

    def _offending_rows(ok):
        """Return the indices of rows with at least one failed check."""
        return np.flatnonzero(~ok.all(axis=1)).tolist()

    non_negative = ballots >= 0
    if not non_negative.all():
        raise ValueError(
            "Negative rankings on ballots: {}".format(_offending_rows(non_negative))
        )

    # An empty mark may only be followed by another empty mark.
    is_empty = ballots == 0
    gap_free = ~is_empty[:, :-1] | is_empty[:, 1:]
    if not gap_free.all():
        raise ValueError(
            "Skipped rankings on ballots: {}".format(_offending_rows(gap_free))
        )

    return ballots
