"""Minimalist implementation of Single Transferable Vote."""
import collections
import itertools

import numpy as np


class PyStvError(Exception):
    """Base error type raised by PyStv when a vote count goes wrong."""

    pass


# Per-round summary produced by run_stv:
#   count      -- vote totals per slot, indexed by candidate id (slot 0 holds
#                 votes from ballots with no remaining preferences)
#   elected    -- candidates elected in this round
#   eliminated -- candidates eliminated in this round
#   transfers  -- {source candidate: {destination slot: votes transferred}}
# The typename must match the bound name so instances repr and pickle
# correctly (pickle looks the class up by this name).
RoundResult = collections.namedtuple(
    "RoundResult", ["count", "elected", "eliminated", "transfers"]
)


def run_stv(ballots, votes, num_seats):
    """Run a Single Transferable Vote count and return per-round results.

    Args:
        ballots: Sequence of ballots, each a sequence of candidate ids in
            preference order.  Id 0 is reserved for "no mark"; ballots may
            have different lengths.
        votes: Number of votes cast with each ballot, parallel to ``ballots``.
        num_seats: Number of candidates to elect.

    Returns:
        List of ``RoundResult``, one per round.  ``count`` is indexed by
        candidate id, with index 0 collecting exhausted votes; ``transfers``
        maps each elected/eliminated candidate to ``{recipient: votes}``.

    Raises:
        ValueError: If the ballots fail validation.
        PyStvError: If a round's total drifts from the total votes cast.
    """
    ballots = validate_and_standardize_ballots(ballots)
    votes = np.array(votes)[:, None]

    num_slots = np.max(ballots) + 1  # zero is reserved for empty marks
    row_idx = np.arange(ballots.shape[0])

    # weights[i, j]: share of ballot i's votes held by the entry at rank j.
    # Every ballot starts entirely on its first preference.
    weights = np.zeros_like(ballots, float)
    weights[:, 0] = 1

    # Droop quota, floor(total / (seats + 1)) + 1.  A plain ceiling misses the
    # "+1" when the division is exact, which lets seats + 1 candidates tie at
    # the quota and all be elected.
    votes_needed = int(np.sum(votes) // (num_seats + 1)) + 1

    # Slot status: 0 hopeful, 1 elected, -1 eliminated, -2 empty mark.
    status = np.zeros(num_slots, int)
    status[0] = -2
    # valid[i, j] turns False once the candidate at rank j of ballot i has been
    # elected or eliminated; empty marks (0) always stay valid.
    valid = np.ones_like(ballots, bool)

    round_info = []
    # The second condition stops the count when there are fewer candidates
    # than seats and every candidate has already been elected.
    while np.count_nonzero(status > 0) < num_seats and np.any(status == 0):
        counts = np.bincount(ballots.ravel(), (weights * votes).ravel(), num_slots)
        counts_masked = np.ma.array(counts, mask=(status != 0))
        elected_mask = counts_masked >= votes_needed

        elected = []
        eliminated = []

        if np.count_nonzero(status >= 0) <= num_seats:
            # No more hopefuls than open seats: elect all of them.
            to_remove = []
            elected = np.nonzero(status == 0)[0]
            multipliers = []
            status[status == 0] = 1
        elif elected_mask.any():
            to_remove = np.nonzero(elected_mask)[0]
            elected = to_remove
            # Share of each winner's votes it keeps so it holds exactly the quota.
            multipliers = [votes_needed / counts_masked[c] for c in to_remove]
            status[to_remove] = 1
        else:  # eliminate candidate with fewest votes (random selection if >1)
            min_counts = np.where(counts_masked == counts_masked.min())[0]
            to_remove = [np.random.choice(min_counts)]
            eliminated = to_remove
            multipliers = [0]
            status[to_remove] = -1

        # Distribute votes from elected/eliminated candidates to the next valid
        # lower rank on each ballot (an empty mark, i.e. slot 0, once the
        # ballot is exhausted).  Elected candidates transfer excess votes,
        # while eliminated candidates transfer all votes.
        transfers = {}
        orig = first_nonzero(valid, axis=1)  # Current rank on each ballot
        orig_cand = ballots[row_idx, orig]
        # Invalidate every candidate removed this round before transferring, so
        # a winner's surplus skips co-winners instead of being absorbed by them.
        valid[np.isin(ballots, to_remove)] = False
        nxt = first_nonzero(valid, axis=1)  # Next valid rank on each ballot

        for cand, mult in zip(to_remove, multipliers):
            # Ballots currently sitting on ``cand`` whose position moves.
            o_rows = (orig_cand == cand) & (orig != nxt)
            # Of those, ballots that have a later valid rank to move to.
            n_rows = o_rows & (nxt != -1)

            # Votes are transferred to next candidate.
            weights[n_rows, nxt[n_rows]] = weights[n_rows, orig[n_rows]] * (1 - mult)
            # Votes are transferred from original candidate.
            weights[o_rows, orig[o_rows]] *= mult

            transfers_cand = collections.defaultdict(int)
            slicer = n_rows, nxt[n_rows]
            for c, v in zip(ballots[slicer], (weights * votes)[slicer]):
                if v:
                    transfers_cand[c] += v

            if transfers_cand:
                transfers[cand] = dict(transfers_cand)

        # Fractional transfers accumulate floating-point rounding error, so an
        # exact comparison could fail spuriously; use a tight tolerance.
        if not np.isclose(counts.sum(), votes.sum(), rtol=1e-9):
            raise PyStvError("Final round count total does not equal original votes")

        round_info.append(
            RoundResult(counts.tolist(), list(elected), list(eliminated), transfers)
        )
    return round_info


def first_nonzero(arr, axis, invalid_val=-1):
    """Return the index of the first non-zero entry of ``arr`` along ``axis``.

    Slices that contain no non-zero entry get ``invalid_val`` (default -1).
    """
    is_set = arr != 0
    # argmax over a boolean array picks the first True, but it also yields 0
    # for an all-False slice, so those slices are overridden below.
    first_idx = is_set.argmax(axis=axis)
    found = is_set.any(axis=axis)
    return np.where(found, first_idx, invalid_val)


def validate_and_standardize_ballots(ballots):
    """Convert ragged ballots into a zero-padded 2-D array and validate them.

    Args:
        ballots: Sequence of ballots, each a sequence of candidate ids in
            preference order.  Shorter ballots are padded with 0 ("no mark").

    Returns:
        A 2-D array with one row per ballot plus one extra trailing 0 column,
        so every row ends with at least one empty mark.

    Raises:
        ValueError: If the data is not 2-D (e.g. no rankings at all, or nested
            ballots), contains negative rankings, or skips a ranking (an
            empty mark followed by a candidate).
    """
    # zip_longest transposes the ragged ballots, padding short ones with 0;
    # .T turns the result back into one row per ballot.
    ballots = np.array(list(itertools.zip_longest(*ballots, fillvalue=0))).T

    # Check the shape before padding: np.pad with a 2-D pad spec fails with an
    # unrelated broadcasting error on non-2-D input.
    if ballots.ndim != 2:
        raise ValueError("Ballot data has wrong dim: %s" % ballots.ndim)

    # Add a 0 at the end of every ballot.
    ballots = np.pad(ballots, ((0, 0), (0, 1)))

    non_negative_rankings = ballots >= 0
    if not non_negative_rankings.all():
        bad_ballots = ~non_negative_rankings.all(axis=1)
        bad_indices = np.nonzero(bad_ballots)[0].tolist()
        raise ValueError("Negative rankings on ballots: %s" % bad_indices)

    # A ranking is continuous when every empty mark is followed by another
    # empty mark, i.e. no candidate appears after a gap.
    first = ballots[:, :-1] == 0
    second = ballots[:, 1:] == 0
    continuous_rankings = ~first | second

    if not continuous_rankings.all():
        bad_ballots = ~continuous_rankings.all(axis=1)
        bad_indices = np.nonzero(bad_ballots)[0].tolist()
        raise ValueError("Skipped rankings on ballots: %s" % bad_indices)

    return ballots
