jax==0.3.2
jaxlib==0.3.2
optax>=0.1.0
chex==0.1.3
distrax>=0.1.2
tensorflow==2.8.0
tensorflow-probability==0.16.0
tqdm>=4.0.0
ml-collections==0.1.0

[dev]
black
isort
pylint
flake8

[docs]
furo==2020.12.30b24

[tests]
pytest
