import pytest
import numpy as np
from numpy.testing import assert_allclose

import keras
from keras.utils.test_utils import layer_test
from keras.layers import recurrent
from keras.layers import embeddings
from keras.models import Sequential
from keras.models import Model
from keras.engine import Input
from keras.layers import Masking
from keras import regularizers
from keras import backend as K

# Dimensions shared by the tests below.
num_samples = 2
timesteps = 5
embedding_dim = 4
units = 3
# Vocabulary size used by the Embedding-based tests.
embedding_num = 12


# Parametrize a test over the three built-in recurrent layers.
rnn_test = pytest.mark.parametrize(
    'layer_class',
    [recurrent.SimpleRNN, recurrent.GRU, recurrent.LSTM])


# Parametrize a test over the matching built-in recurrent cells.
rnn_cell_test = pytest.mark.parametrize(
    'cell_class',
    [recurrent.SimpleRNNCell, recurrent.GRUCell, recurrent.LSTMCell])


@rnn_test
def test_return_sequences(layer_class):
    """Run the generic layer checks with ``return_sequences=True``."""
    layer_kwargs = dict(units=units, return_sequences=True)
    layer_test(layer_class,
               kwargs=layer_kwargs,
               input_shape=(num_samples, timesteps, embedding_dim))


@rnn_test
def test_dynamic_behavior(layer_class):
    """A layer declared with an unknown time dimension trains on real data."""
    layer = layer_class(units, input_shape=(None, embedding_dim))
    model = Sequential([layer])
    model.compile('sgd', 'mse')
    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    targets = np.random.random((num_samples, units))
    model.train_on_batch(inputs, targets)


@rnn_test
def DISABLED_test_stateful_invalid_use(layer_class):
    """A stateful layer rejects batches that do not match its fixed size."""
    fixed_batch_shape = (num_samples, timesteps, embedding_dim)
    layer = layer_class(units,
                        stateful=True,
                        batch_input_shape=fixed_batch_shape)
    model = Sequential([layer])
    model.compile('sgd', 'mse')

    doubled = num_samples * 2
    x = np.random.random((doubled, timesteps, embedding_dim))
    y = np.random.random((doubled, units))
    # fit() with its default batch size; presumably the resulting batch
    # differs from the fixed stateful batch of `num_samples`.
    with pytest.raises(ValueError):
        model.fit(x, y)
    # predict() with an explicitly mismatched batch size.
    with pytest.raises(ValueError):
        model.predict(x, batch_size=num_samples + 1)


@rnn_test
@pytest.mark.skipif((K.backend() in ['theano']),
                    reason='Not supported.')
def test_dropout(layer_class):
    """Check input/recurrent dropout for both values of ``unroll``.

    For each ``unroll`` setting this:
      * runs the generic ``layer_test`` checks with dropout enabled,
      * calls a dropout layer directly, in the default phase and with
        ``training=True``, and checks the output width, and
      * checks that ``predict`` is deterministic, i.e. dropout is not
        applied at inference time.
    """
    for unroll in [True, False]:
        layer_test(layer_class,
                   kwargs={'units': units,
                           'dropout': 0.1,
                           'recurrent_dropout': 0.1,
                           'unroll': unroll},
                   input_shape=(num_samples, timesteps, embedding_dim))

        # Dropout layers can be called in the default and the training
        # phase. `unroll` is forwarded so both loop implementations are
        # exercised here, not just the default one.
        x = K.ones((num_samples, timesteps, embedding_dim))
        layer = layer_class(units, dropout=0.5, recurrent_dropout=0.5,
                            unroll=unroll,
                            input_shape=(timesteps, embedding_dim))
        y = layer(x)
        assert K.int_shape(y)[-1] == units

        y = layer(x, training=True)
        assert K.int_shape(y)[-1] == units

        # Dropout must not be applied at inference time: repeated
        # predictions on the same input are identical.
        x = np.random.random((num_samples, timesteps, embedding_dim))
        layer = layer_class(units, dropout=0.5, recurrent_dropout=0.5,
                            unroll=unroll,
                            input_shape=(timesteps, embedding_dim))
        model = Sequential([layer])
        y1 = model.predict(x)
        y2 = model.predict(x)
        assert_allclose(y1, y2)


@rnn_test
def test_statefulness(layer_class):
    """Stateful layers carry state across batches until it is reset."""
    model = Sequential()
    model.add(embeddings.Embedding(embedding_num, embedding_dim,
                                   mask_zero=True,
                                   input_length=timesteps,
                                   batch_input_shape=(num_samples, timesteps)))
    layer = layer_class(units, return_sequences=False,
                        stateful=True,
                        weights=None)
    model.add(layer)
    model.compile(optimizer='sgd', loss='mse')

    batch = np.ones((num_samples, timesteps))
    first = model.predict(batch)
    assert first.shape == (num_samples, units)

    # One training step so that weights and carried state both change.
    model.train_on_batch(batch, np.ones((num_samples, units)))
    second = model.predict(batch)
    assert first.max() != second.max()

    # Resetting only the layer's state (weights untouched) changes output.
    layer.reset_states()
    third = model.predict(batch)
    assert second.max() != third.max()

    # Model-level reset_states() behaves like the layer-level reset.
    model.reset_states()
    fourth = model.predict(batch)
    assert_allclose(third, fourth, atol=1e-5)

    # predict() itself advances the state, so a repeated call differs.
    fifth = model.predict(batch)
    assert fourth.max() != fifth.max()


@rnn_test
def test_masking_correctness(layer_class):
    """Left- and right-padded inputs must give the same final output."""
    model = Sequential()
    model.add(embeddings.Embedding(embedding_num, embedding_dim,
                                   mask_zero=True,
                                   input_length=timesteps,
                                   batch_input_shape=(num_samples, timesteps)))
    model.add(layer_class(units, return_sequences=False))
    model.compile(optimizer='sgd', loss='mse')

    # Token 0 is masked by the Embedding; sample 0 gets one padded step
    # and sample 1 gets two, on opposite ends in the two inputs.
    left = np.ones((num_samples, timesteps))
    right = np.ones((num_samples, timesteps))
    for row, pad in enumerate([1, 2]):
        left[row, :pad] = 0
        right[row, -pad:] = 0

    out_left = model.predict(left)
    out_right = model.predict(right)
    assert_allclose(out_right, out_left, atol=1e-5)


@pytest.mark.skipif(K.backend() == 'cntk', reason='Not supported.')
def test_masking_correctness_output_not_equal_to_first_state():
    """Masked steps must keep both the previous output and the states.

    The custom cell outputs its input unchanged and increments each state
    by one per step, so the output is not equal to the first state. A
    masked timestep must neither update the output nor increment the
    states. Checked for both ``unroll=True`` and ``unroll=False``.
    """

    class Cell(keras.layers.Layer):

        def __init__(self):
            self.state_size = None
            self.output_size = None
            super(Cell, self).__init__()

        def build(self, input_shape):
            self.state_size = input_shape[-1]
            self.output_size = input_shape[-1]

        def call(self, inputs, states):
            # Output is the raw input; every state is incremented by one.
            return inputs, [s + 1 for s in states]

    num_samples = 5
    num_timesteps = 4
    state_size = input_size = 3  # also equal to `output_size`

    # random inputs and state values
    x_vals = np.random.random((num_samples, num_timesteps, input_size))
    # last timestep masked for first sample (all zero inputs masked by Masking layer)
    x_vals[0, -1, :] = 0
    s_initial_vals = np.random.random((num_samples, state_size))

    # final outputs equal to last inputs
    y_vals_expected = x_vals[:, -1].copy()
    # except for first sample, where it is equal to second to last value due to mask
    y_vals_expected[0] = x_vals[0, -2]

    s_final_vals_expected = s_initial_vals.copy()
    # states are incremented `num_timesteps - 1` times for first sample
    s_final_vals_expected[0] += (num_timesteps - 1)
    # and `num_timesteps` times for remaining samples (nothing masked)
    s_final_vals_expected[1:] += num_timesteps

    for unroll in [True, False]:
        x = Input((num_timesteps, input_size), name="x")
        x_masked = Masking()(x)
        s_initial = Input((state_size,), name="s_initial")
        y, s_final = recurrent.RNN(Cell(),
                                   return_state=True,
                                   unroll=unroll)(x_masked, initial_state=s_initial)
        model = Model([x, s_initial], [y, s_final])
        model.compile(optimizer='sgd', loss='mse')

        y_vals, s_final_vals = model.predict([x_vals, s_initial_vals])
        assert_allclose(y_vals,
                        y_vals_expected,
                        err_msg="Unexpected output for unroll={}".format(unroll))
        assert_allclose(s_final_vals,
                        s_final_vals_expected,
                        err_msg="Unexpected state for unroll={}".format(unroll))


@pytest.mark.skipif(K.backend() == 'cntk', reason='Not supported.')
def test_masking_correctness_output_size_not_equal_to_first_state_size():
    """Masking works when the cell's output size differs from its state size.

    The custom cell outputs its input concatenated with itself (twice the
    state size) and increments each state by one per step. A masked
    timestep must neither update the output nor increment the states.
    Checked for both ``unroll=True`` and ``unroll=False``.
    """

    class Cell(keras.layers.Layer):

        def __init__(self):
            self.state_size = None
            self.output_size = None
            super(Cell, self).__init__()

        def build(self, input_shape):
            self.state_size = input_shape[-1]
            self.output_size = input_shape[-1] * 2

        def call(self, inputs, states):
            # Output is the input concatenated with itself; every state is
            # incremented by one.
            return keras.layers.concatenate([inputs] * 2), [s + 1 for s in states]

    num_samples = 5
    num_timesteps = 6
    input_size = state_size = 7

    # random inputs and state values
    x_vals = np.random.random((num_samples, num_timesteps, input_size))
    # last timestep masked for first sample (all zero inputs masked by Masking layer)
    x_vals[0, -1, :] = 0
    s_initial_vals = np.random.random((num_samples, state_size))

    # final outputs equal to last inputs concatenated
    y_vals_expected = np.concatenate([x_vals[:, -1]] * 2, axis=-1)
    # except for first sample, where it is equal to second to last value due to mask
    y_vals_expected[0] = np.concatenate([x_vals[0, -2]] * 2, axis=-1)

    s_final_vals_expected = s_initial_vals.copy()
    # states are incremented `num_timesteps - 1` times for first sample
    s_final_vals_expected[0] += (num_timesteps - 1)
    # and `num_timesteps` times for remaining samples (nothing masked)
    s_final_vals_expected[1:] += num_timesteps

    for unroll in [True, False]:
        x = Input((num_timesteps, input_size), name="x")
        x_masked = Masking()(x)
        s_initial = Input((state_size,), name="s_initial")
        y, s_final = recurrent.RNN(Cell(),
                                   return_state=True,
                                   unroll=unroll)(x_masked, initial_state=s_initial)
        model = Model([x, s_initial], [y, s_final])
        model.compile(optimizer='sgd', loss='mse')

        y_vals, s_final_vals = model.predict([x_vals, s_initial_vals])
        assert_allclose(y_vals,
                        y_vals_expected,
                        err_msg="Unexpected output for unroll={}".format(unroll))
        assert_allclose(s_final_vals,
                        s_final_vals_expected,
                        err_msg="Unexpected state for unroll={}".format(unroll))


@rnn_test
def test_implementation_mode(layer_class):
    """Both implementation modes work plain, with dropout and without bias."""
    variants = [
        {},                                          # without dropout
        {'dropout': 0.1, 'recurrent_dropout': 0.1},  # with dropout
        {'use_bias': False},                         # without bias
    ]
    for mode in [1, 2]:
        for extra in variants:
            kwargs = {'units': units, 'implementation': mode}
            kwargs.update(extra)
            layer_test(layer_class,
                       kwargs=kwargs,
                       input_shape=(num_samples, timesteps, embedding_dim))


@rnn_test
def test_regularizer(layer_class):
    """Weight regularizers register losses; the activity regularizer is kept."""
    common = dict(return_sequences=False, weights=None,
                  input_shape=(timesteps, embedding_dim))

    layer = layer_class(units,
                        kernel_regularizer=regularizers.l1(0.01),
                        recurrent_regularizer=regularizers.l1(0.01),
                        bias_regularizer='l2',
                        **common)
    layer.build((None, None, embedding_dim))
    # Three regularized weights -> three losses, on the layer and its cell.
    assert len(layer.losses) == 3
    assert len(layer.cell.losses) == 3

    layer = layer_class(units, activity_regularizer='l2', **common)
    assert layer.activity_regularizer
    x = K.variable(np.ones((num_samples, timesteps, embedding_dim)))
    layer(x)


@rnn_test
def test_trainability(layer_class):
    """Toggling ``trainable`` moves all three weights between the two lists."""
    layer = layer_class(units)
    layer.build((None, None, embedding_dim))

    def check(n_trainable, n_non_trainable):
        # The total never changes; only the split between lists does.
        assert len(layer.weights) == 3
        assert len(layer.trainable_weights) == n_trainable
        assert len(layer.non_trainable_weights) == n_non_trainable

    check(3, 0)
    layer.trainable = False
    check(0, 3)
    layer.trainable = True
    check(3, 0)


def test_masking_layer():
    '''Regression test for a previously failing issue:
    https://github.com/keras-team/keras/issues/1567

    Trains a Masking + SimpleRNN model once with ``unroll=False`` and once
    with ``unroll=True``.
    '''
    inputs = np.random.random((6, 3, 4))
    targets = np.abs(np.random.random((6, 3, 5)))
    # Normalize so each target row is a probability distribution.
    targets /= targets.sum(axis=-1, keepdims=True)

    for unroll in (False, True):
        model = Sequential()
        model.add(Masking(input_shape=(3, 4)))
        model.add(recurrent.SimpleRNN(units=5, return_sequences=True,
                                      unroll=unroll))
        model.compile(loss='categorical_crossentropy', optimizer='adam')
        model.fit(inputs, targets, epochs=1, batch_size=100, verbose=1)


@rnn_test
def test_from_config(layer_class):
    """get_config/from_config round-trips for stateless and stateful layers."""
    for stateful in (False, True):
        original = layer_class(units=1, stateful=stateful)
        restored = layer_class.from_config(original.get_config())
        assert original.get_config() == restored.get_config()


@rnn_test
def test_specify_initial_state_keras_tensor(layer_class):
    """Initial states given as Keras ``Input`` tensors become model inputs."""
    num_states = 2 if layer_class is recurrent.LSTM else 1

    inputs = Input((timesteps, embedding_dim))
    initial_state = [Input((units,)) for _ in range(num_states)]
    layer = layer_class(units)
    # A single state is passed bare; several states are passed as a list.
    state_arg = initial_state[0] if num_states == 1 else initial_state
    output = layer(inputs, initial_state=state_arg)
    inbound_ids = [id(t) for t in layer._inbound_nodes[0].input_tensors]
    assert id(initial_state[0]) in inbound_ids

    model = Model([inputs] + initial_state, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    x = np.random.random((num_samples, timesteps, embedding_dim))
    states = [np.random.random((num_samples, units))
              for _ in range(num_states)]
    y = np.random.random((num_samples, units))
    model.fit([x] + states, y)


@rnn_test
def test_specify_initial_state_non_keras_tensor(layer_class):
    """Initial states given as backend variables rather than Keras Inputs."""
    num_states = 2 if layer_class is recurrent.LSTM else 1

    inputs = Input((timesteps, embedding_dim))
    initial_state = []
    for _ in range(num_states):
        initial_state.append(
            K.random_normal_variable((num_samples, units), 0, 1))
    output = layer_class(units)(inputs, initial_state=initial_state)

    model = Model(inputs, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    x = np.random.random((num_samples, timesteps, embedding_dim))
    y = np.random.random((num_samples, units))
    model.fit(x, y)


@rnn_test
def test_reset_states_with_values(layer_class):
    """``reset_states`` zeroes every state, or sets it to the given values.

    Every state is checked, not only the first; for LSTM there are two.
    Also checks that passing more values than there are states raises
    ``ValueError``.
    """
    num_states = 2 if layer_class is recurrent.LSTM else 1

    layer = layer_class(units, stateful=True)
    layer.build((num_samples, timesteps, embedding_dim))
    layer.reset_states()
    assert len(layer.states) == num_states
    assert layer.states[0] is not None
    for state in layer.states:
        np.testing.assert_allclose(K.eval(state),
                                   np.zeros(K.int_shape(state)),
                                   atol=1e-4)

    state_shapes = [K.int_shape(state) for state in layer.states]
    values = [np.ones(shape) for shape in state_shapes]
    # A single state is passed without wrapping it in a list.
    if len(values) == 1:
        values = values[0]
    layer.reset_states(values)
    for state in layer.states:
        np.testing.assert_allclose(K.eval(state),
                                   np.ones(K.int_shape(state)),
                                   atol=1e-4)

    # Passing more values than there are states is rejected.
    with pytest.raises(ValueError):
        layer.reset_states([1] * (len(layer.states) + 1))


@rnn_test
def test_initial_states_as_other_inputs(layer_class):
    """Initial states can be passed as extra entries of the input list."""
    num_states = 2 if layer_class is recurrent.LSTM else 1

    main_inputs = Input((timesteps, embedding_dim))
    initial_state = [Input((units,)) for _ in range(num_states)]
    inputs = [main_inputs] + initial_state

    layer = layer_class(units)
    output = layer(inputs)
    node_input_ids = {id(t) for t in layer._inbound_nodes[0].input_tensors}
    assert id(initial_state[0]) in node_input_ids

    model = Model(inputs, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    x = np.random.random((num_samples, timesteps, embedding_dim))
    states = [np.random.random((num_samples, units))
              for _ in range(num_states)]
    y = np.random.random((num_samples, units))
    model.train_on_batch([x] + states, y)


@rnn_test
def test_specify_state_with_masking(layer_class):
    ''' This test based on a previously failing issue here:
    https://github.com/keras-team/keras/issues/1567

    Trains a layer that receives both masked inputs and explicitly
    specified initial states. The layer must be applied to the Masking
    output; applying it to the raw inputs would never exercise masking.
    '''
    num_states = 2 if layer_class is recurrent.LSTM else 1

    inputs = Input((timesteps, embedding_dim))
    masked = Masking()(inputs)
    initial_state = [Input((units,)) for _ in range(num_states)]
    output = layer_class(units)(masked, initial_state=initial_state)

    model = Model([inputs] + initial_state, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    initial_state = [np.random.random((num_samples, units))
                     for _ in range(num_states)]
    targets = np.random.random((num_samples, units))
    model.fit([inputs] + initial_state, targets)


@rnn_test
def test_return_state(layer_class):
    """``return_state=True`` returns the final states held by a stateful layer."""
    num_states = 2 if layer_class is recurrent.LSTM else 1

    inputs = Input(batch_shape=(num_samples, timesteps, embedding_dim))
    layer = layer_class(units, return_state=True, stateful=True)
    states = layer(inputs)[1:]
    assert len(states) == num_states
    model = Model(inputs, states[0])

    x = np.random.random((num_samples, timesteps, embedding_dim))
    predicted_state = model.predict(x)
    # The returned first state matches what the layer now stores.
    np.testing.assert_allclose(K.eval(layer.states[0]), predicted_state,
                               atol=1e-4)


@rnn_test
def test_state_reuse(layer_class):
    """States returned by one layer can seed a second layer.

    The first layer returns its full sequence plus its final states; the
    second layer consumes that sequence starting from those states.
    """
    inputs = Input(batch_shape=(num_samples, timesteps, embedding_dim))
    layer = layer_class(units, return_state=True, return_sequences=True)
    outputs = layer(inputs)
    output, state = outputs[0], outputs[1:]
    output = layer_class(units)(output, initial_state=state)
    model = Model(inputs, output)

    inputs = np.random.random((num_samples, timesteps, embedding_dim))
    outputs = model.predict(inputs)
    # The second layer returns only its last output: one vector per sample.
    assert outputs.shape == (num_samples, units)


@rnn_test
@pytest.mark.skipif((K.backend() in ['theano']),
                    reason='Not supported.')
def test_state_reuse_with_dropout(layer_class):
    """States from a dropout layer can seed a layer on a separate input."""
    input1 = Input(batch_shape=(num_samples, timesteps, embedding_dim))
    layer = layer_class(units, return_state=True, return_sequences=True, dropout=0.2)
    state = layer(input1)[1:]

    input2 = Input(batch_shape=(num_samples, timesteps, embedding_dim))
    output = layer_class(units)(input2, initial_state=state)
    model = Model([input1, input2], output)

    inputs = [np.random.random((num_samples, timesteps, embedding_dim)),
              np.random.random((num_samples, timesteps, embedding_dim))]
    outputs = model.predict(inputs)
    # The second layer returns only its last output: one vector per sample.
    assert outputs.shape == (num_samples, units)


def test_minimal_rnn_cell_non_layer():
    """RNN accepts plain objects (not Layers) as cells, alone and stacked."""

    class MinimalRNNCell(object):
        """Plain-object cell: ``output = inputs @ kernel + previous_output``."""

        def __init__(self, units, input_dim):
            self.units = units
            self.state_size = units
            self.kernel = keras.backend.variable(
                np.random.random((input_dim, units)))

        def call(self, inputs, states):
            output = keras.backend.dot(inputs, self.kernel) + states[0]
            return output, [output]

    def train_once(rnn_layer):
        # Wrap `rnn_layer` in a model and run a single training step.
        model = keras.models.Model(x, rnn_layer(x))
        model.compile(optimizer='rmsprop', loss='mse')
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    # Single cell.
    cell = MinimalRNNCell(32, 5)
    x = keras.Input((None, 5))
    train_once(recurrent.RNN(cell))

    # Stacked cells.
    train_once(recurrent.RNN([MinimalRNNCell(8, 5),
                              MinimalRNNCell(32, 8),
                              MinimalRNNCell(32, 32)]))


def test_minimal_rnn_cell_non_layer_multiple_states():
    """Plain-object cells with two states work alone and stacked."""

    class MinimalRNNCell(object):
        """Plain-object cell with a tuple ``state_size`` of two entries."""

        def __init__(self, units, input_dim):
            self.units = units
            self.state_size = (units, units)
            self.kernel = keras.backend.variable(
                np.random.random((input_dim, units)))

        def call(self, inputs, states):
            first_state, second_state = states[0], states[1]
            output = keras.backend.dot(inputs, self.kernel)
            output += first_state
            output -= second_state
            return output, [output * 2, output * 3]

    def train_once(rnn_layer):
        # Wrap `rnn_layer` in a model and run a single training step.
        model = keras.models.Model(x, rnn_layer(x))
        model.compile(optimizer='rmsprop', loss='mse')
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))

    # Single cell.
    cell = MinimalRNNCell(32, 5)
    x = keras.Input((None, 5))
    train_once(recurrent.RNN(cell))

    # Stacked cells.
    train_once(recurrent.RNN([MinimalRNNCell(8, 5),
                              MinimalRNNCell(16, 8),
                              MinimalRNNCell(32, 16)]))


def test_minimal_rnn_cell_layer():
    """A custom Layer cell trains and survives config round-trips.

    Covers a single cell and a stack of three cells.
    """

    class MinimalRNNCell(keras.layers.Layer):
        """Cell computing ``inputs @ kernel + prev_output @ recurrent_kernel``."""

        def __init__(self, units, **kwargs):
            self.units = units
            self.state_size = units
            super(MinimalRNNCell, self).__init__(**kwargs)

        def build(self, input_shape):
            # RNN cells receive per-step shapes, without the time axis.
            assert len(input_shape) == 2

            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True

        def call(self, inputs, states):
            output = (keras.backend.dot(inputs, self.kernel) +
                      keras.backend.dot(states[0], self.recurrent_kernel))
            return output, [output]

        def get_config(self):
            config = dict(super(MinimalRNNCell, self).get_config())
            config['units'] = self.units
            return config

    def build_and_train(rnn_layer):
        # Wrap `rnn_layer` in a model and run one training step on zeros.
        model = keras.models.Model(x, rnn_layer(x))
        model.compile(optimizer='rmsprop', loss='mse')
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
        return model

    def check_round_trip(rnn_layer, model):
        # A layer rebuilt from its config, given the same weights, must
        # reproduce the original predictions.
        x_np = np.random.random((6, 5, 5))
        expected = model.predict(x_np)
        weights = model.get_weights()
        config = rnn_layer.get_config()
        with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}):
            restored = recurrent.RNN.from_config(config)
        restored_model = keras.models.Model(x, restored(x))
        restored_model.set_weights(weights)
        assert_allclose(expected, restored_model.predict(x_np), atol=1e-4)

    x = keras.Input((None, 5))

    # Single cell.
    layer = recurrent.RNN(MinimalRNNCell(32))
    check_round_trip(layer, build_and_train(layer))

    # Three stacked cells.
    layer = recurrent.RNN([MinimalRNNCell(8),
                           MinimalRNNCell(12),
                           MinimalRNNCell(32)])
    check_round_trip(layer, build_and_train(layer))


@rnn_cell_test
def test_builtin_rnn_cell_layer(cell_class):
    """Built-in cells train alone and stacked, and survive config round-trips."""
    x = keras.Input((None, 5))

    def build_and_train(rnn_layer):
        # Wrap `rnn_layer` in a model and run one training step on zeros.
        model = keras.models.Model(x, rnn_layer(x))
        model.compile(optimizer='rmsprop', loss='mse')
        model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
        return model

    def check_round_trip(rnn_layer, model):
        # A layer rebuilt from its config, given the same weights, must
        # reproduce the original predictions.
        x_np = np.random.random((6, 5, 5))
        expected = model.predict(x_np)
        weights = model.get_weights()
        restored = recurrent.RNN.from_config(rnn_layer.get_config())
        restored_model = keras.models.Model(x, restored(x))
        restored_model.set_weights(weights)
        assert_allclose(expected, restored_model.predict(x_np), atol=1e-4)

    # Single cell.
    layer = recurrent.RNN(cell_class(32))
    check_round_trip(layer, build_and_train(layer))

    # Three stacked cells.
    layer = recurrent.RNN([cell_class(8), cell_class(12), cell_class(32)])
    check_round_trip(layer, build_and_train(layer))


@pytest.mark.skipif((K.backend() in ['cntk', 'theano']),
                    reason='Not supported.')
def test_stacked_rnn_dropout():
    """Two stacked LSTM cells with input and recurrent dropout can train."""
    cells = [recurrent.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)
             for _ in range(2)]
    layer = recurrent.RNN(cells)

    x = keras.Input((None, 5))
    model = keras.models.Model(x, layer(x))
    model.compile('sgd', 'mse')
    model.train_on_batch(np.random.random((6, 5, 5)),
                         np.random.random((6, 3)))


def test_stacked_rnn_attributes():
    """A stacked RNN exposes its cells' losses and (non-)trainable weights."""
    first_cell = recurrent.LSTMCell(3)
    second_cell = recurrent.LSTMCell(3, kernel_regularizer='l2')
    layer = recurrent.RNN([first_cell, second_cell])
    layer.build((None, None, 5))

    # Only the second cell is regularized, so exactly one loss.
    assert len(layer.losses) == 1

    # Each cell contributes three weights; freezing the first cell moves
    # its weights to the non-trainable list.
    assert len(layer.trainable_weights) == 6
    first_cell.trainable = False
    assert len(layer.trainable_weights) == 3
    assert len(layer.non_trainable_weights) == 3

    # Adding an input-conditional loss to a cell must not raise.
    x = keras.Input((None, 5))
    first_cell.add_loss(K.sum(x), inputs=x)


def test_stacked_rnn_compute_output_shape():
    """Output shapes of a stacked RNN that returns sequences and states."""
    cells = [recurrent.LSTMCell(3),
             recurrent.LSTMCell(6)]
    input_shape = (None, timesteps, embedding_dim)
    sequence_shape = (None, timesteps, 6)

    def shapes_of(rnn_layer):
        return [tuple(s) for s in rnn_layer.compute_output_shape(input_shape)]

    # Default order: the first cell's two state shapes come first.
    layer = recurrent.RNN(cells, return_state=True, return_sequences=True)
    assert shapes_of(layer) == [sequence_shape,
                                (None, 3), (None, 3),
                                (None, 6), (None, 6)]

    # reverse_state_order=True lists the last cell's states first.
    stacked_cell = recurrent.StackedRNNCells(cells, reverse_state_order=True)
    layer = recurrent.RNN(stacked_cell,
                          return_state=True, return_sequences=True)
    assert shapes_of(layer) == [sequence_shape,
                                (None, 6), (None, 6),
                                (None, 3), (None, 3)]


@rnn_test
def test_batch_size_equal_one(layer_class):
    """A model with a fixed batch size of one builds and trains."""
    inputs = Input(batch_shape=(1, timesteps, embedding_dim))
    model = Model(inputs, layer_class(units)(inputs))
    model.compile('sgd', 'mse')
    model.train_on_batch(np.random.random((1, timesteps, embedding_dim)),
                         np.random.random((1, units)))


def DISABLED_test_rnn_cell_with_constants_layer():
    """Custom cell whose ``call`` also receives ``constants``.

    Disabled via the ``DISABLED_`` prefix (pytest collects ``test*``
    functions by default, so this is presumably skipped on purpose).
    Checks training, get_config/from_config round-trips with both the
    ``constants=`` keyword and a flat ``[x, c]`` input list, and a stack
    that mixes a GRUCell with constant-aware cells.
    """

    class RNNCellWithConstants(keras.layers.Layer):
        # Cell computing x @ input_kernel + h_prev @ recurrent_kernel
        # + constant @ constant_kernel.

        def __init__(self, units, **kwargs):
            self.units = units
            self.state_size = units
            super(RNNCellWithConstants, self).__init__(**kwargs)

        def build(self, input_shape):
            # Expects a list of [step input shape, constant shape].
            if not isinstance(input_shape, list):
                raise TypeError('expects `constants` shape')
            [input_shape, constant_shape] = input_shape
            # will (and should) raise if more than one constant passed

            self.input_kernel = self.add_weight(
                shape=(input_shape[-1], self.units),
                initializer='uniform',
                name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.constant_kernel = self.add_weight(
                shape=(constant_shape[-1], self.units),
                initializer='uniform',
                name='constant_kernel')
            self.built = True

        def call(self, inputs, states, constants):
            # NOTE(review): the parameter name `constants` is presumably how
            # the RNN wrapper decides to pass constants; do not rename it.
            [prev_output] = states
            [constant] = constants
            h_input = keras.backend.dot(inputs, self.input_kernel)
            h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
            h_const = keras.backend.dot(constant, self.constant_kernel)
            output = h_input + h_state + h_const
            return output, [output]

        def get_config(self):
            config = {'units': self.units}
            base_config = super(RNNCellWithConstants, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))

    # Test basic case.
    x = keras.Input((None, 5))
    c = keras.Input((3,))
    cell = RNNCellWithConstants(32)
    layer = recurrent.RNN(cell)
    y = layer(x, constants=c)
    model = keras.models.Model([x, c], y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(
        [np.zeros((6, 5, 5)), np.zeros((6, 3))],
        np.zeros((6, 32))
    )

    # Test basic case serialization.
    # A layer rebuilt from config with the same weights must reproduce
    # the original predictions.
    x_np = np.random.random((6, 5, 5))
    c_np = np.random.random((6, 3))
    y_np = model.predict([x_np, c_np])
    weights = model.get_weights()
    config = layer.get_config()
    custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
    with keras.utils.CustomObjectScope(custom_objects):
        layer = recurrent.RNN.from_config(config.copy())
    y = layer(x, constants=c)
    model = keras.models.Model([x, c], y)
    model.set_weights(weights)
    y_np_2 = model.predict([x_np, c_np])
    assert_allclose(y_np, y_np_2, atol=1e-4)

    # test flat list inputs
    # Same round-trip, but constants passed as the second list element.
    with keras.utils.CustomObjectScope(custom_objects):
        layer = recurrent.RNN.from_config(config.copy())
    y = layer([x, c])
    model = keras.models.Model([x, c], y)
    model.set_weights(weights)
    y_np_3 = model.predict([x_np, c_np])
    assert_allclose(y_np, y_np_3, atol=1e-4)

    # Test stacking.
    # The first cell is a built-in GRUCell, whose `call` has no `constants`
    # parameter; presumably only constant-aware cells receive them.
    cells = [recurrent.GRUCell(8),
             RNNCellWithConstants(12),
             RNNCellWithConstants(32)]
    layer = recurrent.RNN(cells)
    y = layer(x, constants=c)
    model = keras.models.Model([x, c], y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(
        [np.zeros((6, 5, 5)), np.zeros((6, 3))],
        np.zeros((6, 32))
    )

    # Test stacked RNN serialization.
    x_np = np.random.random((6, 5, 5))
    c_np = np.random.random((6, 3))
    y_np = model.predict([x_np, c_np])
    weights = model.get_weights()
    config = layer.get_config()
    with keras.utils.CustomObjectScope(custom_objects):
        layer = recurrent.RNN.from_config(config.copy())
    y = layer(x, constants=c)
    model = keras.models.Model([x, c], y)
    model.set_weights(weights)
    y_np_2 = model.predict([x_np, c_np])
    assert_allclose(y_np, y_np_2, atol=1e-4)


def DISABLED_test_rnn_cell_with_constants_layer_passing_initial_state():
    """RNN driven by a constants-aware cell plus an explicit initial state.

    Trains the model for one batch, then round-trips the layer through
    ``get_config`` / ``from_config`` and checks three things:
    1. predictions are reproduced after reloading the weights;
    2. changing the initial state changes the output;
    3. calling the layer with a flat list ``[x, s, c]`` matches the
       keyword form ``layer(x, initial_state=s, constants=c)``.
    """

    class RNNCellWithConstants(keras.layers.Layer):
        """Minimal cell: output = x.W_in + h_prev.W_rec + c.W_const."""

        def __init__(self, units, **kwargs):
            self.units = units
            self.state_size = units
            super(RNNCellWithConstants, self).__init__(**kwargs)

        def build(self, input_shape):
            if not isinstance(input_shape, list):
                raise TypeError('expects constants shape')
            # Unpacking into exactly two names will (and should) raise
            # if more than one constant was passed.
            step_shape, constant_shape = input_shape

            self.input_kernel = self.add_weight(
                shape=(step_shape[-1], self.units),
                initializer='uniform',
                name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.constant_kernel = self.add_weight(
                shape=(constant_shape[-1], self.units),
                initializer='uniform',
                name='constant_kernel')
            self.built = True

        def call(self, inputs, states, constants):
            previous, = states
            constant, = constants
            from_input = keras.backend.dot(inputs, self.input_kernel)
            from_state = keras.backend.dot(previous, self.recurrent_kernel)
            from_const = keras.backend.dot(constant, self.constant_kernel)
            output = from_input + from_state + from_const
            # The new output doubles as the next state.
            return output, [output]

        def get_config(self):
            merged = dict(super(RNNCellWithConstants, self).get_config())
            merged.update({'units': self.units})
            return merged

    # Basic case: x is the sequence, c the constant, s the initial state.
    x = keras.Input((None, 5))
    c = keras.Input((3,))
    s = keras.Input((32,))
    layer = recurrent.RNN(RNNCellWithConstants(32))
    y = layer(x, initial_state=s, constants=c)
    model = keras.models.Model([x, s, c], y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(
        [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))],
        np.zeros((6, 32))
    )

    # Reference predictions, weights and config from the trained model.
    x_np = np.random.random((6, 5, 5))
    s_np = np.random.random((6, 32))
    c_np = np.random.random((6, 3))
    y_np = model.predict([x_np, s_np, c_np])
    weights = model.get_weights()
    config = layer.get_config()
    custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}

    def rebuild(connect):
        """Recreate the layer from ``config``, wire it up via ``connect``
        and load the reference weights into a fresh model."""
        with keras.utils.CustomObjectScope(custom_objects):
            restored_layer = recurrent.RNN.from_config(config.copy())
        restored_model = keras.models.Model([x, s, c],
                                            connect(restored_layer))
        restored_model.set_weights(weights)
        return restored_model

    # Serialization round-trip using keyword arguments.
    restored = rebuild(lambda rnn: rnn(x, initial_state=s, constants=c))
    assert_allclose(y_np, restored.predict([x_np, s_np, c_np]), atol=1e-4)

    # The initial state must actually influence the output.
    shifted = restored.predict([x_np, s_np + 10., c_np])
    with pytest.raises(AssertionError):
        assert_allclose(y_np, shifted, atol=1e-4)

    # Flat list inputs must behave like the keyword form.
    flat = rebuild(lambda rnn: rnn([x, s, c]))
    assert_allclose(y_np, flat.predict([x_np, s_np, c_np]), atol=1e-4)


@rnn_test
def DISABLED_test_rnn_cell_identity_initializer(layer_class):
    """With ``recurrent_initializer='identity'`` every square block of the
    recurrent kernel should be an identity matrix."""
    n_units = 2
    layer = layer_class(n_units, recurrent_initializer='identity')
    layer(Input(shape=(1, 2)))

    # NOTE(review): assumes get_weights() is ordered
    # [kernel, recurrent_kernel, ...] -- confirm per layer class.
    recurrent_kernel = layer.get_weights()[1]
    rows, cols = recurrent_kernel.shape
    # The kernel may hold several (units, units) blocks side by side along
    # axis 1 (presumably one per gate for multi-gate layers).
    expected = np.tile(np.identity(n_units), (1, cols // rows))
    assert np.array_equal(recurrent_kernel, expected)


@pytest.mark.skipif(K.backend() == 'cntk', reason='Not supported.')
def test_inconsistent_output_state_size():
    """RNN whose cell output width differs from its state width.

    The cell keeps a state of width ``state_size`` but emits outputs as wide
    as its input. The initial state must therefore follow ``state_size``,
    and the model output must follow the input width.
    """

    class PlusOneRNNCell(keras.layers.Layer):
        """Add one to the input and state.

        This cell is used for testing state_size and output_size."""

        def __init__(self, num_unit, **kwargs):
            self.state_size = num_unit
            super(PlusOneRNNCell, self).__init__(**kwargs)

        def build(self, input_shape):
            # Output width comes from the last input dimension,
            # independently of the state width.
            self.output_size = input_shape[-1]

        def call(self, inputs, states):
            return inputs + 1, [states[0] + 1]

    batch_size, n_steps = 32, 4
    state_dim, input_dim = 5, 6

    rnn_cell = PlusOneRNNCell(state_dim)
    seq_input = keras.Input((None, input_dim))
    rnn_layer = recurrent.RNN(rnn_cell)
    seq_output = rnn_layer(seq_input)

    assert rnn_cell.state_size == state_dim

    initial_states = rnn_layer.get_initial_state(seq_input)
    assert len(initial_states) == 1
    # theano does not support static shape inference, so the inferred
    # shape is only checked on the other backends.
    if K.backend() != 'theano':
        assert K.int_shape(initial_states[0]) == (None, state_dim)

    model = keras.models.Model(seq_input, seq_output)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((batch_size, n_steps, input_dim)),
                         np.zeros((batch_size, input_dim)))
    assert model.output_shape == (None, input_dim)


if __name__ == '__main__':
    # Allow running this test module directly with `python <file>`.
    pytest.main(args=[__file__])
