Metadata-Version: 2.1
Name: jax-tqdm
Version: 0.1.0
Summary: Tqdm progress bar for JAX scans and loops
Home-page: https://github.com/jeremiecoullon/jax-tqdm
License: MIT
Keywords: jax,tqdm
Author: Jeremie Coullon
Author-email: jeremie.coullon@gmail.com
Requires-Python: >=3.8,<4.0
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Dist: jax (>=0.3.5)
Requires-Dist: tqdm (>=4.64.1,<5.0.0)
Project-URL: Repository, https://github.com/jeremiecoullon/jax-tqdm
Description-Content-Type: text/markdown

# JAX-tqdm

Add a tqdm progress bar to your JAX scans and loops.

The code is explained in this [blog post](https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/).
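The blog post covers the real implementation. As a rough illustration only, here is a minimal sketch of the general idea: a decorator wraps the loop body and, from inside the traced (and possibly jitted) loop, calls back to the host Python process to move a `tqdm` bar. The name `simple_loop_tqdm` is hypothetical and is not part of jax-tqdm. The sketch also assumes `jax.debug.callback`, which needs a newer JAX than this package's minimum and may not be the callback mechanism the library or the blog post uses.

```python
import jax
from jax import lax
from tqdm import tqdm


def simple_loop_tqdm(n):
    """Hypothetical decorator: show a tqdm bar while lax.fori_loop runs `n` steps."""
    state = {}  # holds the host-side tqdm bar between callbacks

    def _update(i):
        # Runs on the host with a concrete value of the loop index.
        i = int(i)
        if i == 0 or "bar" not in state:
            state["bar"] = tqdm(total=n)
        state["bar"].n = i + 1  # set, rather than increment, the bar position
        state["bar"].refresh()
        if i == n - 1:
            state["bar"].close()

    def decorator(body):
        def wrapped(i, val):
            # jax.debug.callback lets traced/jitted code call a Python function.
            jax.debug.callback(_update, i)
            return body(i, val)

        return wrapped

    return decorator


n = 1_000


@simple_loop_tqdm(n)
def step(i, val):
    return val + 1


# The callback still fires when the whole loop is compiled with jax.jit.
last_number = jax.jit(lambda: lax.fori_loop(0, n, step, 0))()
```

Setting `bar.n` from the index, instead of calling `update(1)` per step, keeps the bar correct even if a callback were to arrive out of order.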

## Example usage

### in `jax.lax.scan`

```python
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp

n = 10_000

@scan_tqdm(n)
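def step(carry, x):
    # x is the current element of xs; with xs = jnp.arange(n) it equals the step index
    return carry + 1, carry + 1

# last_number == n, and all_numbers holds 1, 2, ..., n
last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))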
```
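`lax.scan` does not hand its body a loop counter, which is presumably why the example passes `jnp.arange(n)` as `xs`: each `x` is then the current step index. The sketch below assumes that design and shows a hypothetical `simple_scan_tqdm` decorator (not the jax-tqdm API) that reads the index from `x`, or from the first element when `xs` is a tuple, so the body can scan over real data as well. It uses `jax.debug.callback` for the host-side update, which needs a newer JAX than this package's minimum.

```python
import jax
import jax.numpy as jnp
from jax import lax
from tqdm import tqdm


def simple_scan_tqdm(n):
    """Hypothetical decorator: tqdm bar for a lax.scan body whose xs start with the step index."""
    state = {}  # host-side tqdm bar shared across callbacks

    def _update(iter_num):
        iter_num = int(iter_num)
        if iter_num == 0 or "bar" not in state:
            state["bar"] = tqdm(total=n)
        state["bar"].n = iter_num + 1
        state["bar"].refresh()
        if iter_num == n - 1:
            state["bar"].close()

    def decorator(body):
        def wrapped(carry, x):
            # The body has no counter of its own, so take it from x:
            # either x itself or the first element of a tuple of inputs.
            iter_num = x[0] if isinstance(x, tuple) else x
            jax.debug.callback(_update, iter_num)
            return body(carry, x)

        return wrapped

    return decorator


n = 1_000
data = jnp.ones(n)


@simple_scan_tqdm(n)
def step(carry, x):
    _, value = x  # x is (step index, data element)
    total = carry + value
    return total, total


# Pass the step index alongside the data as a tuple of xs.
final_total, running_totals = lax.scan(step, 0.0, (jnp.arange(n), data))
```

Because `lax.scan` slices every leaf of the `xs` pytree, the body receives a tuple `(index, value)` at each step, and the index travels with the data at no extra bookkeeping cost.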


### in `jax.lax.fori_loop`

```python
from jax_tqdm import loop_tqdm
from jax import lax

n = 10_000

@loop_tqdm(n)
def step(i, val):
    # i is the loop index, val the value carried between iterations
    return val + 1

last_number = lax.fori_loop(0, n, step, 0)  # last_number == n
```

