chex==0.1.3
distrax>=0.1.2
jax==0.3.2
jaxlib==0.3.2
ml-collections==0.1.0
optax>=0.1.0
tensorflow-probability==0.16.0
tensorflow==2.8.0
tqdm>=4.0.0

[dev]
black
flake8
isort
pylint

[docs]
furo

[tests]
pytest
