try:
    from unittest.mock import MagicMock, create_autospec
except ImportError:
    from mock import create_autospec, MagicMock

import asyncio
import time

import pytest
from reretry.api import _is_async, retry, retry_call


def test_retry(monkeypatch):
    """An always-failing function is attempted ``tries`` times.

    The total time slept equals the geometric series
    ``delay * backoff**k`` for ``k`` in ``0 .. tries - 2``.
    """
    sleeps = []
    # Record requested sleep durations instead of actually sleeping.
    monkeypatch.setattr(time, "sleep", sleeps.append)

    tries, delay, backoff = 5, 1, 2
    calls = 0

    @retry(tries=tries, delay=delay, backoff=backoff)
    def always_fails():
        nonlocal calls
        calls += 1
        1 / 0

    with pytest.raises(ZeroDivisionError):
        always_fails()

    assert calls == tries
    expected_total = sum(delay * backoff**k for k in range(tries - 1))
    assert sum(sleeps) == expected_total


def test_tries_inf():
    """With ``tries=float("inf")`` the call is retried until it succeeds."""
    target = 10
    attempts = 0

    @retry(tries=float("inf"))
    def succeeds_eventually():
        nonlocal attempts
        attempts += 1
        if attempts != target:
            raise ValueError
        return target

    assert succeeds_eventually() == target


def test_tries_minus1():
    """With ``tries=-1`` the call is retried until it succeeds."""
    target = 10
    attempts = 0

    @retry(tries=-1)
    def succeeds_eventually():
        nonlocal attempts
        attempts += 1
        if attempts != target:
            raise ValueError
        return target

    assert succeeds_eventually() == target


def test_max_delay(monkeypatch):
    """``max_delay`` caps how far ``backoff`` can grow the sleep interval.

    The cap equals the initial delay here. Every sleep is therefore
    ``delay``, and the total is ``delay * (tries - 1)``.
    """
    sleeps = []
    # Record requested sleep durations instead of actually sleeping.
    monkeypatch.setattr(time, "sleep", sleeps.append)

    tries, delay, backoff = 5, 1, 2
    max_delay = delay  # cap at the starting value so the delay never increases
    calls = 0

    @retry(tries=tries, delay=delay, max_delay=max_delay, backoff=backoff)
    def always_fails():
        nonlocal calls
        calls += 1
        1 / 0

    with pytest.raises(ZeroDivisionError):
        always_fails()

    assert calls == tries
    assert sum(sleeps) == delay * (tries - 1)


def test_fixed_jitter(monkeypatch):
    """A fixed ``jitter`` is added to the delay after each failed attempt.

    The test uses the default starting delay and ``jitter=1``. The expected
    total is ``0 + 1 + ... + (tries - 2)``, i.e. ``sum(range(tries - 1))``.
    """
    sleeps = []
    # Record requested sleep durations instead of actually sleeping.
    monkeypatch.setattr(time, "sleep", sleeps.append)

    tries = 10
    jitter = 1
    calls = 0

    @retry(tries=tries, jitter=jitter)
    def always_fails():
        nonlocal calls
        calls += 1
        1 / 0

    with pytest.raises(ZeroDivisionError):
        always_fails()

    assert calls == tries
    assert sum(sleeps) == sum(range(tries - 1))


def test_retry_call():
    """retry_call calls ``f`` exactly ``tries`` times and then re-raises."""
    f_mock = MagicMock(side_effect=RuntimeError)
    tries = 2
    # pytest.raises (instead of try/except/pass) makes the test fail if the
    # final RuntimeError is not propagated once the tries are exhausted.
    with pytest.raises(RuntimeError):
        retry_call(f_mock, exceptions=RuntimeError, tries=tries)

    assert f_mock.call_count == tries


def test_retry_call_2():
    """retry_call returns the first successful result after transient failures.

    The mock raises RuntimeError twice and then returns 3. That fits
    within ``tries=5``, so the call must succeed after three attempts.
    """
    side_effect = [RuntimeError, RuntimeError, 3]
    f_mock = MagicMock(side_effect=side_effect)
    tries = 5
    # No try/except: an unexpected RuntimeError should surface with its own
    # traceback instead of being masked as ``result is None``.
    result = retry_call(f_mock, exceptions=RuntimeError, tries=tries)

    assert result == 3
    assert f_mock.call_count == len(side_effect)


def test_retry_call_with_args():
    """Positional ``fargs`` are forwarded to the wrapped callable."""
    def f(value=0):
        if value < 0:
            return value
        else:
            raise RuntimeError

    return_value = -1
    # ``f`` only supplies the spec; the mock returns ``return_value``
    # whatever it is called with.
    f_mock = MagicMock(spec=f, return_value=return_value)
    result = retry_call(f_mock, fargs=[return_value])

    assert result == return_value
    # The mock's return value does not depend on its arguments, so check
    # forwarding explicitly.
    f_mock.assert_called_once_with(return_value)


def test_retry_call_with_kwargs():
    """Keyword ``fkwargs`` are forwarded to the wrapped callable."""
    def f(value=0):
        if value < 0:
            return value
        else:
            raise RuntimeError

    kwargs = {"value": -1}
    # ``f`` only supplies the spec; the mock returns ``kwargs["value"]``
    # whatever it is called with.
    f_mock = MagicMock(spec=f, return_value=kwargs["value"])
    result = retry_call(f_mock, fkwargs=kwargs)

    assert result == kwargs["value"]
    # The mock's return value does not depend on its arguments, so check
    # forwarding explicitly.
    f_mock.assert_called_once_with(**kwargs)


def test_retry_call_with_fail_callback():
    """``fail_callback`` is invoked when the wrapped callable fails."""
    def f():
        raise RuntimeError

    def cb(error):
        pass

    callback_mock = MagicMock(spec=cb)
    # ``f`` always fails, so after both tries the error must reach the caller;
    # pytest.raises checks that instead of silently swallowing it.
    with pytest.raises(RuntimeError):
        retry_call(f, fail_callback=callback_mock, tries=2)

    callback_mock.assert_called()


def test_is_async():
    """Only the coroutine function is classified as async.

    Plain functions, generator functions, generator objects and a
    spec'd MagicMock are all reported as not async.
    """
    async def coroutine_fn():
        pass

    def plain_fn():
        pass

    def gen_fn():
        yield

    assert _is_async(coroutine_fn)

    not_async = [
        plain_fn,
        gen_fn,
        gen_fn(),
        MagicMock(spec=plain_fn, return_value=-1),
    ]
    for candidate in not_async:
        assert not _is_async(candidate)


@pytest.mark.asyncio
async def test_async():
    """``retry`` also retries coroutine functions.

    The first await raises RuntimeError. With ``tries=2`` the second
    attempt runs and returns True.
    """
    remaining_failures = 1
    did_raise = False

    @retry(tries=2)
    async def flaky():
        nonlocal remaining_failures, did_raise
        await asyncio.sleep(0.1)
        if not remaining_failures:
            return True
        did_raise = True
        remaining_failures -= 1
        raise RuntimeError

    assert await flaky()
    assert did_raise
    assert remaining_failures == 0


def test_check_params():
    """Invalid parameter combinations are rejected with AssertionError."""
    # show_traceback=True together with logger=None is rejected.
    with pytest.raises(AssertionError):
        retry_call(lambda: None, show_traceback=True, logger=None)

    async def coroutine_target():
        pass

    def sync_callback():
        pass

    # An async target paired with a synchronous fail_callback is rejected.
    with pytest.raises(AssertionError):
        retry_call(coroutine_target, fail_callback=sync_callback)
