import pytest
from permuta.misc import ordered_set_partitions


def test_ordered_set_partitions():
    """Check ``ordered_set_partitions`` against hand-enumerated expectations.

    ``ordered_set_partitions(elements, sizes)`` is exercised both lazily
    (pulling items with ``next`` to pin the generation order) and eagerly
    (materialising all results and comparing sorted lists, which checks the
    exact set of partitions, including that there are no duplicates).
    """
    # Lazy iteration: pins the exact generation order for sizes [2, 1] and
    # that the iterator is exhausted afterwards.
    it = ordered_set_partitions([1, 2, 3], [2, 1])
    assert [[1, 2], [3]] == next(it)
    assert [[1, 3], [2]] == next(it)
    assert [[2, 3], [1]] == next(it)
    with pytest.raises(StopIteration):
        next(it)

    # Reversed part sizes [1, 2] (previously this section duplicated the
    # [2, 1] check above verbatim). Compared sorted so only the set of
    # results matters, not the order they are produced in.
    lst = list(ordered_set_partitions([1, 2, 3], [1, 2]))
    assert [
        [[1], [2, 3]],
        [[2], [1, 3]],
        [[3], [1, 2]],
    ] == sorted(lst)

    # Three parts of sizes 2, 1, 2 over five elements: 5!/(2!*1!*2!) = 30.
    lst = list(ordered_set_partitions([2, 3, 4, 5, 6], [2, 1, 2]))
    assert [
        [[2, 3], [4], [5, 6]],
        [[2, 3], [5], [4, 6]],
        [[2, 3], [6], [4, 5]],
        [[2, 4], [3], [5, 6]],
        [[2, 4], [5], [3, 6]],
        [[2, 4], [6], [3, 5]],
        [[2, 5], [3], [4, 6]],
        [[2, 5], [4], [3, 6]],
        [[2, 5], [6], [3, 4]],
        [[2, 6], [3], [4, 5]],
        [[2, 6], [4], [3, 5]],
        [[2, 6], [5], [3, 4]],
        [[3, 4], [2], [5, 6]],
        [[3, 4], [5], [2, 6]],
        [[3, 4], [6], [2, 5]],
        [[3, 5], [2], [4, 6]],
        [[3, 5], [4], [2, 6]],
        [[3, 5], [6], [2, 4]],
        [[3, 6], [2], [4, 5]],
        [[3, 6], [4], [2, 5]],
        [[3, 6], [5], [2, 4]],
        [[4, 5], [2], [3, 6]],
        [[4, 5], [3], [2, 6]],
        [[4, 5], [6], [2, 3]],
        [[4, 6], [2], [3, 5]],
        [[4, 6], [3], [2, 5]],
        [[4, 6], [5], [2, 3]],
        [[5, 6], [2], [3, 4]],
        [[5, 6], [3], [2, 4]],
        [[5, 6], [4], [2, 3]],
    ] == sorted(lst)

    # Two singleton parts after a size-3 part: 5!/(3!*1!*1!) = 20. The two
    # singletons are distinguishable by position, so both orders appear.
    lst = list(ordered_set_partitions([2, 3, 4, 5, 6], [3, 1, 1]))
    assert [
        [[2, 3, 4], [5], [6]],
        [[2, 3, 4], [6], [5]],
        [[2, 3, 5], [4], [6]],
        [[2, 3, 5], [6], [4]],
        [[2, 3, 6], [4], [5]],
        [[2, 3, 6], [5], [4]],
        [[2, 4, 5], [3], [6]],
        [[2, 4, 5], [6], [3]],
        [[2, 4, 6], [3], [5]],
        [[2, 4, 6], [5], [3]],
        [[2, 5, 6], [3], [4]],
        [[2, 5, 6], [4], [3]],
        [[3, 4, 5], [2], [6]],
        [[3, 4, 5], [6], [2]],
        [[3, 4, 6], [2], [5]],
        [[3, 4, 6], [5], [2]],
        [[3, 5, 6], [2], [4]],
        [[3, 5, 6], [4], [2]],
        [[4, 5, 6], [2], [3]],
        [[4, 5, 6], [3], [2]],
    ] == sorted(lst)
