Metadata-Version: 2.1
Name: fast_rnnt
Version: 1.0
Summary: Fast and memory-efficient RNN-T loss.
Home-page: https://github.com/danpovey/fast_rnnt
Author: Dan Povey
Author-email: dpovey@gmail.com
License: Apache licensed, as found in the LICENSE file
Platform: UNKNOWN
Classifier: Programming Language :: C++
Classifier: Programming Language :: Python
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Description-Content-Type: text/markdown
License-File: LICENSE


This project implements a method for faster and more memory-efficient RNN-T loss computation, called `pruned rnnt`.

Note: There is also a fast RNN-T loss implementation in the [k2](https://github.com/k2-fsa/k2) project, which shares the same code. We made `fast_rnnt` a stand-alone project for those who want only this RNN-T loss.

## How does the pruned RNN-T loss work?

We first obtain pruning bounds for the RNN-T recursion using a simple joiner network that is just an addition of the encoder and decoder outputs. We then use those pruning bounds to evaluate the full, non-linear joiner network only within the pruned region.
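
To see why the pruning matters, here is a back-of-the-envelope calculation of the memory needed just to store the joiner output in float32. The full joiner produces a tensor of shape `(B, T, S + 1, C)`, while the pruned one produces `(B, T, s_range, C)`. The sizes below are hypothetical, and the count ignores gradients and the joiner's hidden activations, so real usage is higher.

```python
def joiner_logits_bytes(B, T, S, C, s_range=None, bytes_per_elem=4):
    """Bytes needed to store joiner logits (float32 by default).

    Full joiner output: (B, T, S + 1, C); pruned joiner output: (B, T, s_range, C).
    """
    width = S + 1 if s_range is None else s_range
    return B * T * width * C * bytes_per_elem


# Hypothetical sizes: batch, frames, symbols per utterance, vocabulary, kept symbols per frame
B, T, S, C, s_range = 32, 500, 100, 500, 5

full = joiner_logits_bytes(B, T, S, C)
pruned = joiner_logits_bytes(B, T, S, C, s_range=s_range)
print(f"full joiner:   {full / 1e9:.3f} GB")    # 3.232 GB
print(f"pruned joiner: {pruned / 1e9:.3f} GB")  # 0.160 GB
print(f"ratio: {full / pruned:.1f}")            # 20.2, i.e. (S + 1) / s_range
```

With these sizes the full joiner output alone needs about 3.2 GB, against 0.16 GB for the pruned one.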

The picture below displays the gradients of the lattice nodes (obtained by `rnnt_loss_simple` with `return_grad=True`). At each time frame, only a small set of nodes has a non-zero gradient, which justifies the pruned RNN-T loss, i.e., putting a limit on the number of symbols per frame.

<img src="https://user-images.githubusercontent.com/5284924/158116784-4dcf1107-2b84-4c0c-90c3-cb4a02f027c9.png" width="900" height="250" />

> This picture is taken from [here](https://github.com/k2-fsa/icefall/pull/251)
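
The idea of limiting the symbols per frame can be sketched for a single utterance: given a `(T, S + 1)` matrix of per-node gradient magnitudes like the one pictured, keep for each frame the window of `s_range` consecutive symbol positions that holds the most gradient mass. This is our own simplified illustration of the idea behind `fast_rnnt.get_rnnt_prune_ranges`, not its implementation; `choose_prune_ranges` is a hypothetical name. As far as we understand, the library applies further constraints (for example, to keep the windows of consecutive frames connected), which this sketch only partly imitates by making the window starts non-decreasing over time.

```python
import torch


def choose_prune_ranges(grad: torch.Tensor, s_range: int) -> torch.Tensor:
    """Hypothetical helper: pick s_range consecutive symbol positions per frame.

    grad: (T, S + 1) non-negative gradient magnitudes of the lattice nodes.
    Returns a (T, s_range) int64 tensor of the symbol positions kept per frame.
    """
    if s_range < 1:
        raise ValueError("s_range must be at least 1")
    T, S1 = grad.shape
    s_range = min(s_range, S1)
    # Prefix sums let us get the mass of every window [s, s + s_range) at once.
    zeros = torch.zeros(T, 1, dtype=grad.dtype)
    cumsum = torch.cat([zeros, grad.cumsum(dim=1)], dim=1)  # (T, S1 + 1)
    window_mass = cumsum[:, s_range:] - cumsum[:, :-s_range]  # (T, S1 - s_range + 1)
    s_begin = window_mass.argmax(dim=1)  # best window start per frame
    # The lattice path never moves backwards in s, so keep the starts non-decreasing.
    s_begin = torch.cummax(s_begin, dim=0).values
    return s_begin.unsqueeze(1) + torch.arange(s_range)


grad = torch.tensor([[0.9, 0.1, 0.0, 0.0],
                     [0.0, 0.8, 0.2, 0.0],
                     [0.9, 0.1, 0.0, 0.0]])
print(choose_prune_ranges(grad, s_range=2).tolist())  # [[0, 1], [1, 2], [1, 2]]
```

In the last frame the best window would start at 0, but the start is pushed up to 1 so that the windows never move backwards.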

## Installation

You can install it via `pip`:

```
pip install fast_rnnt
```

You can also install from source:

```
git clone https://github.com/danpovey/fast_rnnt.git
cd fast_rnnt
python setup.py install
```

To check that `fast_rnnt` was installed successfully, please run

```
python3 -c "import fast_rnnt; print(fast_rnnt.__version__)"
```

which should print the version of the installed `fast_rnnt`, e.g., `1.0`.


### How to display the installation log?

Use

```
pip install --verbose fast_rnnt
```

### How to reduce installation time?

Use

```
export FT_MAKE_ARGS="-j"
pip install --verbose fast_rnnt
```

This passes `-j` to `make`, so that build jobs run in parallel.

### Which versions of PyTorch are supported?

It has been tested on PyTorch >= 1.5.0.

Note: The CUDA version that PyTorch was built with should match the CUDA version installed in your environment;
otherwise, compilation will fail.
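
A quick way to compare the two versions before building is sketched below. It assumes `nvcc` is on your `PATH` and that its output contains a `release X.Y` string (as recent CUDA toolkits print); `torch.version.cuda` is `None` for CPU-only PyTorch builds. The helper names are our own.

```python
import re
import subprocess


def parse_nvcc_release(text):
    """Extract e.g. "11.3" from nvcc output such as "Cuda compilation tools, release 11.3, V11.3.109"."""
    match = re.search(r"release (\d+\.\d+)", text)
    return match.group(1) if match else None


def nvcc_cuda_version():
    """Return the CUDA version reported by `nvcc --version`, or None if nvcc is unavailable."""
    try:
        result = subprocess.run(
            ["nvcc", "--version"], capture_output=True, text=True, check=True
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    import torch

    print("PyTorch was built with CUDA:", torch.version.cuda)
    print("nvcc in this environment:   ", nvcc_cuda_version())
```

If the two printed versions differ, install a PyTorch build that matches your CUDA toolkit (or vice versa) before building `fast_rnnt`.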


### How to install a CPU version of `fast_rnnt`?

Use

```
export FT_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Release -DFT_WITH_CUDA=OFF"
export FT_MAKE_ARGS="-j"
pip install --verbose fast_rnnt
```

It will pass `-DCMAKE_BUILD_TYPE=Release -DFT_WITH_CUDA=OFF` to `cmake`.

### Where to get help if I have problems with the installation?

Please file an issue at <https://github.com/danpovey/fast_rnnt/issues>
and describe your problem there.


## Usage

### For rnnt_loss_simple

This is a simple case of the RNN-T loss, where the joiner network is just
addition.

Note: `termination_symbol` plays the role of the blank symbol in other RNN-T loss implementations; we call it `termination_symbol` because it terminates the symbols of the current frame.

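```python
import torch
import fast_rnnt

# Example sizes: batch, frames, max number of symbols, vocabulary size
B, T, S, C = 2, 50, 10, 100

am = torch.randn((B, T, C), dtype=torch.float32)      # acoustic (encoder) output
lm = torch.randn((B, S + 1, C), dtype=torch.float32)  # language-model (decoder) output
symbols = torch.randint(1, C, (B, S))  # targets; avoid 0, the termination symbol
termination_symbol = 0

# Real lengths of each utterance; here every sequence uses the full length
target_lengths = torch.full((B,), S, dtype=torch.int64)
num_frames = torch.full((B,), T, dtype=torch.int64)

# Each row is [begin_symbol, begin_frame, end_symbol, end_frame]
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames

loss = fast_rnnt.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=termination_symbol,
    boundary=boundary,
    reduction="sum",
)
```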
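
Because the joiner is just an addition, the log-softmax normalizer of every lattice node, `logsumexp` over classes of `am[b, t, :] + lm[b, s, :]`, can be computed with one batched matrix product instead of materializing a `(B, S + 1, T, C)` tensor. The plain-PyTorch sketch below illustrates that idea; it is our own illustration of why the simple loss is cheap, and the function name is hypothetical. We have not checked it against the internal code of `fast_rnnt`.

```python
import torch


def simple_joiner_log_normalizers(am: torch.Tensor, lm: torch.Tensor) -> torch.Tensor:
    """logsumexp over classes of am[b, t, :] + lm[b, s, :] for every (s, t).

    am: (B, T, C), lm: (B, S + 1, C). Returns a (B, S + 1, T) tensor without
    ever building the (B, S + 1, T, C) sum. Precision can suffer for extreme inputs.
    """
    # Subtract per-row maxima so exp() cannot overflow.
    am_max = am.max(dim=2, keepdim=True).values  # (B, T, 1)
    lm_max = lm.max(dim=2, keepdim=True).values  # (B, S + 1, 1)
    am_exp = (am - am_max).exp()  # (B, T, C)
    lm_exp = (lm - lm_max).exp()  # (B, S + 1, C)
    # sum_c exp(lm[s, c]) * exp(am[t, c]) for all (s, t) pairs is one batched matmul.
    sums = torch.matmul(lm_exp, am_exp.transpose(1, 2))  # (B, S + 1, T)
    # Undo the max subtraction in log space.
    return sums.log() + lm_max + am_max.transpose(1, 2)


am = torch.randn(2, 7, 11)
lm = torch.randn(2, 4, 11)
norm = simple_joiner_log_normalizers(am, lm)  # shape (2, 4, 7)
```

The matrix product costs `O(B * T * (S + 1) * C)` operations but only `O(B * T * (S + 1))` memory, which is what makes the simple loss cheap enough to compute pruning bounds with.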

### For rnnt_loss_smoothed

The same as `rnnt_loss_simple`, except that it supports `am_only` and `lm_only` smoothing,
which makes the loss function take the form:

          lm_only_scale * lm_probs +
          am_only_scale * am_probs +
          (1-lm_only_scale-am_only_scale) * combined_probs

where `lm_probs` and `am_probs` are the probabilities given by the LM and the acoustic model independently, and `combined_probs` are those of the combined model, as in `rnnt_loss_simple`.
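
As a worked example of the interpolation, the sketch below applies the formula to a single lattice transition, using the scales from the usage example below (`lm_only_scale=0.25`, `am_only_scale=0.0`). We assume the three terms are log-probabilities of the same transition; the function name and the input values are hypothetical, and this is not the code of `fast_rnnt`.

```python
def smoothed_logprob(combined, lm_only, am_only, lm_only_scale, am_only_scale):
    """Interpolate three log-probabilities with the weights of the formula above.

    Works element-wise on floats or on tensors of the same shape.
    """
    combined_scale = 1.0 - lm_only_scale - am_only_scale
    return lm_only_scale * lm_only + am_only_scale * am_only + combined_scale * combined


# Hypothetical log-probabilities of one transition under each model
value = smoothed_logprob(combined=-0.5, lm_only=-2.0, am_only=-1.0,
                         lm_only_scale=0.25, am_only_scale=0.0)
print(value)  # -0.875, i.e. 0.25 * (-2.0) + 0.75 * (-0.5)
```

With `am_only_scale=0.0` the acoustic-only term drops out, and the remaining weight of 0.75 goes to the combined model.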

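```python
import torch
import fast_rnnt

# Example sizes: batch, frames, max number of symbols, vocabulary size
B, T, S, C = 2, 50, 10, 100

am = torch.randn((B, T, C), dtype=torch.float32)
lm = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols = torch.randint(1, C, (B, S))  # targets; avoid 0, the termination symbol
termination_symbol = 0

target_lengths = torch.full((B,), S, dtype=torch.int64)
num_frames = torch.full((B,), T, dtype=torch.int64)

# Each row is [begin_symbol, begin_frame, end_symbol, end_frame]
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames

loss = fast_rnnt.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=termination_symbol,
    lm_only_scale=0.25,
    am_only_scale=0.0,
    boundary=boundary,
    reduction="sum",
)
```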

### For rnnt_loss_pruned

`rnnt_loss_pruned` cannot be used on its own: it needs the gradients returned by `rnnt_loss_simple` or `rnnt_loss_smoothed` to compute the pruning bounds.

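```python
import torch
import fast_rnnt

# Example sizes: batch, frames, max number of symbols, vocabulary size
B, T, S, C = 2, 50, 10, 100

am = torch.randn((B, T, C), dtype=torch.float32)
lm = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols = torch.randint(1, C, (B, S))  # targets; avoid 0, the termination symbol
termination_symbol = 0

target_lengths = torch.full((B,), S, dtype=torch.int64)
num_frames = torch.full((B,), T, dtype=torch.int64)

# Each row is [begin_symbol, begin_frame, end_symbol, end_frame]
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames

# rnnt_loss_smoothed can be used here instead of rnnt_loss_simple
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=termination_symbol,
    boundary=boundary,
    reduction="sum",
    return_grad=True,
)
s_range = 5  # number of symbol positions kept per frame; can be other values
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=s_range,
)

# Both have shape (B, T, s_range, C)
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

# Stand-in for a real joiner network such as `model.joiner(am_pruned, lm_pruned)`
joiner_proj = torch.nn.Linear(C, C)
logits = joiner_proj(torch.tanh(am_pruned + lm_pruned))  # (B, T, s_range, C)
pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits,
    symbols=symbols,
    ranges=ranges,
    termination_symbol=termination_symbol,
    boundary=boundary,
    reduction="sum",
)
```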

You can also find recipes that use `rnnt_loss_pruned` to train a model [here](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless).
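
As a rough illustration of what the pruning step produces, the plain-PyTorch sketch below gathers the `lm` rows selected by `ranges` (shape `(B, T, s_range)`, holding symbol positions) and broadcasts `am` to match, giving two `(B, T, s_range, C)` tensors that a full joiner can consume. It mirrors our understanding of `fast_rnnt.do_rnnt_pruning` but is not its source; `prune_am_lm` is a hypothetical name.

```python
import torch


def prune_am_lm(am: torch.Tensor, lm: torch.Tensor, ranges: torch.Tensor):
    """Hypothetical helper: select the lm rows listed in `ranges` for every frame.

    am: (B, T, C), lm: (B, S + 1, C), ranges: (B, T, s_range) int64 symbol positions.
    Returns am_pruned and lm_pruned, both of shape (B, T, s_range, C).
    """
    B, T, s_range = ranges.shape
    S1, C = lm.shape[1], lm.shape[2]
    # The same acoustic frame is paired with every kept symbol position.
    am_pruned = am.unsqueeze(2).expand(B, T, s_range, C)
    # View lm once per frame, then pick the kept symbol positions along dim 2.
    lm_per_frame = lm.unsqueeze(1).expand(B, T, S1, C)
    index = ranges.unsqueeze(-1).expand(B, T, s_range, C)
    lm_pruned = torch.gather(lm_per_frame, dim=2, index=index)
    return am_pruned, lm_pruned


B, T, S, C, s_range = 2, 3, 4, 5, 2
am = torch.randn(B, T, C)
lm = torch.randn(B, S + 1, C)
# Window start per frame; each frame keeps s_range consecutive positions.
s_begin = torch.tensor([[0, 1, 3], [0, 0, 2]])
ranges = s_begin.unsqueeze(-1) + torch.arange(s_range)
am_pruned, lm_pruned = prune_am_lm(am, lm, ranges)  # both (2, 3, 2, 5)
```

The joiner then runs on `s_range` symbol positions per frame instead of all `S + 1`, which is where the speed and memory savings come from.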


### For rnnt_loss

The unpruned `rnnt_loss` is the same as torchaudio's `rnnt_loss`: it produces the same output as torchaudio for the same input.

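```python
import torch
import fast_rnnt

# Example sizes: batch, frames, max number of symbols, vocabulary size
B, T, S, C = 2, 50, 10, 100

# Output of the full joiner network: (batch, frames, symbols + 1, classes)
logits = torch.randn((B, T, S + 1, C), dtype=torch.float32)
symbols = torch.randint(1, C, (B, S))  # targets; avoid 0, the termination symbol
termination_symbol = 0

target_lengths = torch.full((B,), S, dtype=torch.int64)
num_frames = torch.full((B,), T, dtype=torch.int64)

# Each row is [begin_symbol, begin_frame, end_symbol, end_frame]
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames

loss = fast_rnnt.rnnt_loss(
    logits=logits,
    symbols=symbols,
    termination_symbol=termination_symbol,
    boundary=boundary,
    reduction="sum",
)
```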
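
To make explicit what the unpruned loss computes, here is a naive single-utterance sketch of the standard RNN-T forward recursion in plain PyTorch. It assumes a full-length, unpadded utterance (no `boundary`), treats index `blank` as the termination symbol, and returns the negative log-likelihood summed over all alignments. It is for illustration only (a Python double loop), `rnnt_loss_reference` is a hypothetical name, and while we expect it to agree with `fast_rnnt.rnnt_loss` for such an utterance, we have not verified that here.

```python
import torch


def rnnt_loss_reference(logits: torch.Tensor, symbols: torch.Tensor, blank: int = 0) -> torch.Tensor:
    """Naive RNN-T loss for one utterance.

    logits: (T, S + 1, C) joiner output; symbols: (S,) target ids.
    Returns -log P(symbols | logits), summed over all alignments.
    """
    log_probs = logits.log_softmax(dim=-1)
    T, S1, _ = log_probs.shape
    # alpha[t][s]: log-prob of all partial alignments reaching frame t with s symbols emitted
    alpha = [[None] * S1 for _ in range(T)]
    for t in range(T):
        for s in range(S1):
            if t == 0 and s == 0:
                alpha[t][s] = log_probs.new_zeros(())
                continue
            terms = []
            if t > 0:  # arrive by emitting the termination symbol at (t - 1, s)
                terms.append(alpha[t - 1][s] + log_probs[t - 1, s, blank])
            if s > 0:  # arrive by emitting symbols[s - 1] at (t, s - 1)
                terms.append(alpha[t][s - 1] + log_probs[t, s - 1, int(symbols[s - 1])])
            alpha[t][s] = torch.logsumexp(torch.stack(terms), dim=0)
    # A final termination symbol at the last node ends the sequence.
    return -(alpha[T - 1][S1 - 1] + log_probs[T - 1, S1 - 1, blank])


logits = torch.randn(4, 3, 6)  # T = 4 frames, S = 2 symbols, C = 6 classes
loss = rnnt_loss_reference(logits, torch.tensor([2, 5]))
```

The batched `logits` passed to `fast_rnnt.rnnt_loss` have the same layout with an extra leading batch dimension, `(B, T, S + 1, C)`.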


## Benchmarking

The [transducer-loss-benchmarking](https://github.com/csukuangfj/transducer-loss-benchmarking) repository compares the speed and memory usage of several transducer losses. The summary in the following table is taken from there; see that repository for more details.

Note: As mentioned above, `fast_rnnt` is also implemented in the [k2](https://github.com/k2-fsa/k2) project, so `k2` and `fast_rnnt` are equivalent in this benchmark.

|Name                |Average step time (µs) | Peak memory usage (MB)|
|--------------------|-----------------------|-----------------------|
|torchaudio          |601447                 |12959.2                |
|fast_rnnt(unpruned) |274407                 |15106.5                |
|fast_rnnt(pruned)   |38112                  |2647.8                 |
|optimized_transducer|567684                 |10903.1                |
|warprnnt_numba      |229340                 |13061.8                |
|warp-transducer     |210772                 |13061.8                |
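
To read the table as ratios, the short script below divides torchaudio's numbers by each entry. Values greater than 1 mean faster or lower peak memory than torchaudio. The numbers are copied from the table; nothing else is assumed.

```python
# (average step time in µs, peak memory in MB), copied from the table above
results = {
    "torchaudio": (601447, 12959.2),
    "fast_rnnt(unpruned)": (274407, 15106.5),
    "fast_rnnt(pruned)": (38112, 2647.8),
    "optimized_transducer": (567684, 10903.1),
    "warprnnt_numba": (229340, 13061.8),
    "warp-transducer": (210772, 13061.8),
}


def relative_to(baseline):
    """Return {name: (speed-up, memory reduction)} relative to `baseline`."""
    base_time, base_mem = results[baseline]
    return {name: (base_time / t, base_mem / m) for name, (t, m) in results.items()}


for name, (speedup, mem_reduction) in relative_to("torchaudio").items():
    print(f"{name:22s} speed-up x{speedup:5.2f}   memory reduction x{mem_reduction:5.2f}")
```

By this measure, the pruned `fast_rnnt` loss is roughly 15.8 times faster than torchaudio and uses about 4.9 times less peak memory.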


