#
# Copyright (c) 2022 Bytedance Inc. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""RNN implementation in python
Originally copied from https://github.com/pytorch/pytorch/blob/v0.4.1/torch/nn/_functions/rnn.py
with the following modifications:
    fusedBackend is removed
    CudnnRNN is removed
    Hack for ONNX in RNN() is removed
Only LSTM is quantized; the other cell types are not exported in __all__.
"""

import warnings
from torch.autograd import NestedIOFunction
from torch.nn import functional as F
import torch
import itertools
from functools import partial

# Public API: only the quantized LSTM cell and the RNN entry point are exported.
# The other cell types defined below are kept from the upstream implementation
# but are not quantized.
__all__ = ["LSTMCell", "RNN"]


def RNNReLUCell(
    input,
    hidden,
    w_ih,
    w_hh,
    b_ih=None,
    b_hh=None,
    input_quantizer=None,
    weight_quantizer=None,
):
    """Single time step of an Elman RNN with ReLU nonlinearity (not quantized).

    Computes ``relu(linear(input, w_ih, b_ih) + linear(hidden, w_hh, b_hh))``.

    ``input_quantizer`` / ``weight_quantizer`` are accepted because the recurrent
    wrappers in this module always pass them as keyword arguments. Quantization is
    only implemented for ``LSTMCell``. A non-None quantizer is therefore rejected
    rather than silently ignored.

    Raises:
        NotImplementedError: If either quantizer is not None.
    """
    if input_quantizer is not None or weight_quantizer is not None:
        raise NotImplementedError(
            "Quantization is only supported for LSTMCell, not RNNReLUCell"
        )
    hy = F.relu(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
    return hy


def RNNTanhCell(
    input,
    hidden,
    w_ih,
    w_hh,
    b_ih=None,
    b_hh=None,
    input_quantizer=None,
    weight_quantizer=None,
):
    """Single time step of an Elman RNN with tanh nonlinearity (not quantized).

    Computes ``tanh(linear(input, w_ih, b_ih) + linear(hidden, w_hh, b_hh))``.

    ``input_quantizer`` / ``weight_quantizer`` are accepted because the recurrent
    wrappers in this module always pass them as keyword arguments. Quantization is
    only implemented for ``LSTMCell``. A non-None quantizer is therefore rejected
    rather than silently ignored.

    Raises:
        NotImplementedError: If either quantizer is not None.
    """
    if input_quantizer is not None or weight_quantizer is not None:
        raise NotImplementedError(
            "Quantization is only supported for LSTMCell, not RNNTanhCell"
        )
    hy = torch.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
    return hy


def LSTMCell(
    input,
    hidden,
    w_ih,
    w_hh,
    b_ih=None,
    b_hh=None,
    input_quantizer=None,
    weight_quantizer=None,
):
    """Single time step of a quantized LSTM.

    The cell is written for the case where, at inference time, one fused GEMM is
    launched per time step. The weights of the four gates are fused, and the layer
    input and the recurrent hidden state are fused together. Accordingly:

    - ``input_quantizer`` is applied once to ``cat([input, hx], dim=1)``.
    - ``weight_quantizer`` is applied once to ``cat([w_ih, w_hh], dim=1)``.

    Each quantized result is split back into its two parts before the matmuls.

    Args:
        input: Layer input for this step, features on dim 1.
        hidden: ``(hx, cx)`` pair from the previous step.
        w_ih, w_hh: Input-to-hidden and hidden-to-hidden weights. Their output rows
            are chunked into the input, forget, cell and output gates, in that order.
        b_ih, b_hh: Optional biases.
        input_quantizer: Optional callable applied to the fused activations.
        weight_quantizer: Optional callable applied to the fused weights.

    Returns:
        ``(hy, cy)``: the new hidden and cell states.
    """
    h_prev, c_prev = hidden
    x_t = input

    if input_quantizer is not None:
        # Quantize [input | hx] as one tensor so a single scale covers both paths.
        fused_act = input_quantizer(torch.cat([x_t, h_prev], 1))
        x_t, h_prev = fused_act.split([x_t.size(1), h_prev.size(1)], 1)

    if weight_quantizer is not None:
        # Quantize the fused [w_ih | w_hh] weight, then split it back.
        fused_w = weight_quantizer(torch.cat([w_ih, w_hh], 1))
        w_ih, w_hh = fused_w.split([w_ih.size(1), w_hh.size(1)], 1)

    pre_act = F.linear(x_t, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)
    i_gate, f_gate, g_gate, o_gate = pre_act.chunk(4, 1)
    i_gate, f_gate, o_gate = map(torch.sigmoid, (i_gate, f_gate, o_gate))
    g_gate = torch.tanh(g_gate)

    c_next = (f_gate * c_prev) + (i_gate * g_gate)
    h_next = o_gate * torch.tanh(c_next)
    return h_next, c_next


def GRUCell(
    input,
    hidden,
    w_ih,
    w_hh,
    b_ih=None,
    b_hh=None,
    input_quantizer=None,
    weight_quantizer=None,
):
    """Single time step of a GRU (not quantized).

    ``w_ih`` / ``w_hh`` produce three chunks along dim 1: reset, input (update)
    and new-gate pre-activations. The new state is
    ``newgate + inputgate * (hidden - newgate)``.

    ``input_quantizer`` / ``weight_quantizer`` are accepted because the recurrent
    wrappers in this module always pass them as keyword arguments. Quantization is
    only implemented for ``LSTMCell``. A non-None quantizer is therefore rejected
    rather than silently ignored.

    Raises:
        NotImplementedError: If either quantizer is not None.
    """
    if input_quantizer is not None or weight_quantizer is not None:
        raise NotImplementedError(
            "Quantization is only supported for LSTMCell, not GRUCell"
        )
    gi = F.linear(input, w_ih, b_ih)
    gh = F.linear(hidden, w_hh, b_hh)
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    # The reset gate scales only the recurrent contribution to the new gate.
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy


def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
    """Stack per-direction recurrent functions into a multi-layer RNN.

    Args:
        inners: Tuple of per-direction recurrent functions: one entry for a
            unidirectional RNN, or (forward, reverse) for a bidirectional one.
        num_layers: Number of stacked layers.
        lstm: If True, ``hidden`` is an ``(h, c)`` pair and the returned hidden
            state is an ``(h, c)`` pair too.
        dropout: Dropout probability applied between layers (not after the last).
        train: Passed to ``F.dropout`` as ``training``.

    Returns:
        ``forward(input, hidden, weight, batch_sizes, input_quantizers=None,
        weight_quantizers=None) -> (next_hidden, output)``. ``weight`` and the
        quantizer lists are indexed by ``layer * num_directions + direction``.
        Passing ``None`` for a quantizer list disables that quantizer for every
        layer and direction.
    """
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(
        input,
        hidden,
        weight,
        batch_sizes,
        input_quantizers=None,
        weight_quantizers=None,
    ):
        assert len(weight) == total_layers
        # ``None`` means "no quantizer" everywhere. Without this, the ``[idx]``
        # lookups below would raise TypeError.
        if input_quantizers is None:
            input_quantizers = [None] * total_layers
        if weight_quantizers is None:
            weight_quantizers = [None] * total_layers
        next_hidden = []

        if lstm:
            # (h_all, c_all) -> list of per-layer/direction (h, c) pairs.
            hidden = list(zip(*hidden))

        for i in range(num_layers):
            all_output = []
            for j, inner in enumerate(inners):
                # Flat index of (layer i, direction j).
                idx = i * num_directions + j

                hy, output = inner(
                    input,
                    hidden[idx],
                    weight[idx],
                    batch_sizes,
                    input_quantizer=input_quantizers[idx],
                    weight_quantizer=weight_quantizers[idx],
                )
                next_hidden.append(hy)
                all_output.append(output)

            # Directions are concatenated along the last (feature) dimension and
            # become the next layer's input.
            input = torch.cat(all_output, input.dim() - 1)

            # Dropout between layers only, never after the last one.
            if dropout != 0 and i < num_layers - 1:
                input = F.dropout(input, p=dropout, training=train, inplace=False)

        if lstm:
            next_h, next_c = zip(*next_hidden)
            next_hidden = (
                torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
                torch.cat(next_c, 0).view(total_layers, *next_c[0].size()),
            )
        else:
            next_hidden = torch.cat(next_hidden, 0).view(
                total_layers, *next_hidden[0].size()
            )

        return next_hidden, input

    return forward


def Recurrent(inner, reverse=False):
    """Wrap a single-step cell into a loop over a time-major sequence.

    The returned ``forward(input, hidden, weight, batch_sizes, input_quantizer,
    weight_quantizer)`` runs ``inner`` once for each index along dim 0 of ``input``.
    Steps run from last to first when ``reverse`` is True. ``batch_sizes`` is
    accepted but unused here.

    Returns ``(hidden, output)``, where ``hidden`` is the state after the final
    step. ``output`` stacks the per-step hidden state along dim 0 in forward time
    order. For tuple states (LSTM) only the first element is stacked.
    """

    def forward(input, hidden, weight, batch_sizes, input_quantizer, weight_quantizer):
        seq_len = input.size(0)
        time_order = reversed(range(seq_len)) if reverse else range(seq_len)
        step_outputs = []
        for t in time_order:
            hidden = inner(
                input[t],
                hidden,
                *weight,
                input_quantizer=input_quantizer,
                weight_quantizer=weight_quantizer
            )
            # LSTM cells return (h, c); only h is emitted as the step output.
            step_outputs.append(hidden[0] if isinstance(hidden, tuple) else hidden)

        if reverse:
            # Put outputs back into forward time order.
            step_outputs = step_outputs[::-1]
        stacked = torch.cat(step_outputs, 0).view(seq_len, *step_outputs[0].size())

        return hidden, stacked

    return forward


def variable_recurrent_factory(inner, reverse=False):
    """Return the packed-sequence (variable length) recurrent function for ``inner``.

    ``reverse=True`` selects the time-reversed variant. ``AutogradRNN`` uses it
    for the second direction of a bidirectional RNN.
    """
    factory = VariableRecurrentReverse if reverse else VariableRecurrent
    return factory(inner)


def VariableRecurrent(inner):
    """Wrap a single-step cell into a forward-time loop over a packed sequence.

    The returned ``forward(input, hidden, weight, batch_sizes, input_quantizer,
    weight_quantizer)`` returns ``(hidden, output)``.

    ``input`` stores the time steps back to back along dim 0. Step ``t`` occupies
    ``batch_sizes[t]`` rows. The loop only ever shrinks the active batch
    (``dec > 0``), so it assumes ``batch_sizes`` is non-increasing (packed-sequence
    order). NOTE(review): confirm that all callers pass sorted batches.

    ``hidden`` is either a single tensor or a tuple such as LSTM's ``(h, c)``. The
    returned ``hidden`` has the same form and holds the last state of every
    sequence. ``output`` concatenates the per-step ``hidden[0]`` along dim 0, in
    the same packed layout as ``input``.
    """
    def forward(input, hidden, weight, batch_sizes, input_quantizer, weight_quantizer):

        output = []
        input_offset = 0
        last_batch_size = batch_sizes[0]
        # Final hidden-state rows of sequences that have already ended, stashed
        # in the order in which they ended.
        hiddens = []
        # Single-tensor states are wrapped in a 1-tuple so both cases are handled
        # uniformly; they are unwrapped again before returning.
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
        for batch_size in batch_sizes:
            # Rows of ``input`` belonging to this time step.
            step_input = input[input_offset : input_offset + batch_size]
            input_offset += batch_size

            # When the batch shrinks, the trailing ``dec`` rows belong to
            # sequences that have ended: stash their final state and keep
            # stepping only the still-active rows.
            dec = last_batch_size - batch_size
            if dec > 0:
                hiddens.append(tuple(h[-dec:] for h in hidden))
                hidden = tuple(h[:-dec] for h in hidden)
            last_batch_size = batch_size

            if flat_hidden:
                hidden = (
                    inner(
                        step_input,
                        hidden[0],
                        *weight,
                        input_quantizer=input_quantizer,
                        weight_quantizer=weight_quantizer
                    ),
                )
            else:
                hidden = inner(
                    step_input,
                    hidden,
                    *weight,
                    input_quantizer=input_quantizer,
                    weight_quantizer=weight_quantizer
                )

            output.append(hidden[0])
        # The rows still active at the end are the leading ones. Earlier stashes
        # hold trailing rows, so reversing puts the chunks in row order before
        # concatenation.
        hiddens.append(hidden)
        hiddens.reverse()

        hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
        assert hidden[0].size(0) == batch_sizes[0]
        if flat_hidden:
            hidden = hidden[0]
        output = torch.cat(output, 0)

        return hidden, output

    return forward


def VariableRecurrentReverse(inner):
    """Wrap a single-step cell into a reverse-time loop over a packed sequence.

    The returned ``forward(input, hidden, weight, batch_sizes, input_quantizer,
    weight_quantizer)`` returns ``(hidden, output)``. ``input`` stores the time
    steps back to back along dim 0, and step ``t`` occupies ``batch_sizes[t]``
    rows. The steps are visited from last to first.

    At the last step, only the first ``batch_sizes[-1]`` rows of the initial
    ``hidden`` are used. As the batch grows going backwards, the initial-state
    rows of the newly active sequences are appended. ``output`` is reversed at
    the end so that it follows the same forward-time packed layout as ``input``.
    ``hidden`` may be a single tensor or a tuple (e.g. LSTM's ``(h, c)``). The
    returned ``hidden`` has the same form: the state after processing step 0.
    """
    def forward(input, hidden, weight, batch_sizes, input_quantizer, weight_quantizer):
        output = []
        # Input is consumed from its end, one time step at a time.
        input_offset = input.size(0)
        last_batch_size = batch_sizes[-1]
        # Kept so rows for sequences that become active later in the loop can
        # be pulled from the caller-provided initial state.
        initial_hidden = hidden
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
            initial_hidden = (initial_hidden,)
        # Start with only the sequences present at the last time step.
        hidden = tuple(h[: batch_sizes[-1]] for h in hidden)
        for i in reversed(range(len(batch_sizes))):
            batch_size = batch_sizes[i]
            # If the batch grows, the extra rows belong to sequences whose last
            # step is this one; they start from their initial hidden state.
            inc = batch_size - last_batch_size
            if inc > 0:
                hidden = tuple(
                    torch.cat((h, ih[last_batch_size:batch_size]), 0)
                    for h, ih in zip(hidden, initial_hidden)
                )
            last_batch_size = batch_size
            step_input = input[input_offset - batch_size : input_offset]
            input_offset -= batch_size

            if flat_hidden:
                hidden = (
                    inner(
                        step_input,
                        hidden[0],
                        *weight,
                        input_quantizer=input_quantizer,
                        weight_quantizer=weight_quantizer
                    ),
                )
            else:
                hidden = inner(
                    step_input,
                    hidden,
                    *weight,
                    input_quantizer=input_quantizer,
                    weight_quantizer=weight_quantizer
                )
            output.append(hidden[0])

        # Outputs were collected last-step-first; restore forward time order.
        output.reverse()
        output = torch.cat(output, 0)
        if flat_hidden:
            hidden = hidden[0]
        return hidden, output

    return forward


def AutogradRNN(
    mode,
    input_size,
    hidden_size,
    num_layers=1,
    batch_first=False,
    dropout=0,
    train=True,
    bidirectional=False,
    variable_length=False,
    dropout_state=None,
    flat_weight=None,
    input_quantizers=None,
    weight_quantizers=None,
):
    """Build the forward function of a stacked, optionally bidirectional RNN.

    Args:
        mode: One of ``"RNN_RELU"``, ``"RNN_TANH"``, ``"LSTM"`` or ``"GRU"``.
        input_size, hidden_size, dropout_state, flat_weight: Not referenced by
            this function.
        num_layers: Number of stacked layers.
        batch_first: If True and ``variable_length`` is False, dims 0 and 1 of
            ``input`` are swapped before the RNN runs, and of the output after.
        dropout: Dropout probability between layers.
        train: Training flag forwarded to dropout.
        bidirectional: Also run a reverse-time pass in every layer.
        variable_length: Use the packed-sequence loops, which are driven by
            ``batch_sizes``.
        input_quantizers, weight_quantizers: Fallback per-layer/direction
            quantizer lists. They are used when ``forward`` is called without
            quantizers (or with ``None``).

    Returns:
        ``forward(input, weight, hidden, batch_sizes, input_quantizers=None,
        weight_quantizers=None) -> (output, next_hidden)``.

    Raises:
        ValueError: If ``mode`` is unknown. This is a subclass of ``Exception``,
            so existing ``except Exception`` handlers still catch it.
    """
    cells = {
        "RNN_RELU": RNNReLUCell,
        "RNN_TANH": RNNTanhCell,
        "LSTM": LSTMCell,
        "GRU": GRUCell,
    }
    if mode not in cells:
        raise ValueError("Unknown mode: {}".format(mode))
    cell = cells[mode]

    rec_factory = variable_recurrent_factory if variable_length else Recurrent

    if bidirectional:
        layer = (rec_factory(cell), rec_factory(cell, reverse=True))
    else:
        layer = (rec_factory(cell),)

    func = StackedRNN(layer, num_layers, (mode == "LSTM"), dropout=dropout, train=train)

    # The constructor-level quantizers were previously accepted but ignored; they
    # now act as defaults for ``forward``.
    default_input_quantizers = input_quantizers
    default_weight_quantizers = weight_quantizers

    def forward(
        input,
        weight,
        hidden,
        batch_sizes,
        input_quantizers=None,
        weight_quantizers=None,
    ):
        if input_quantizers is None:
            input_quantizers = default_input_quantizers
        if weight_quantizers is None:
            weight_quantizers = default_weight_quantizers

        # Packed (variable-length) input has no batch/time axes to swap.
        if batch_first and not variable_length:
            input = input.transpose(0, 1)

        nexth, output = func(
            input, hidden, weight, batch_sizes, input_quantizers, weight_quantizers
        )

        if batch_first and not variable_length:
            output = output.transpose(0, 1)

        return output, nexth

    return forward


def RNN(*args, **kwargs):
    """Return a callable that runs an ``AutogradRNN`` configured with these args.

    All positional and keyword arguments are forwarded to ``AutogradRNN``. A fresh
    ``AutogradRNN`` forward function is built on every call of the returned
    callable, and the call arguments are passed straight through to it.
    """
    ctor_args, ctor_kwargs = args, kwargs

    def forward(input, *fargs, **fkwargs):
        rnn_impl = AutogradRNN(*ctor_args, **ctor_kwargs)
        return rnn_impl(input, *fargs, **fkwargs)

    return forward
