jax>=0.2.5
jaxlib>=0.1.57
numpy==1.18.5
multipledispatch==0.6.0
packaging==20.4
chex==0.0.4
tfp-nightly==0.12.0.dev20201123
optax==0.0.2

[dev]
black
isort
pylint
flake8

[docs]
furo==2020.12.30b24
nbsphinx==0.8.1
nb-black==1.0.7
matplotlib==3.3.3
sphinx-copybutton==0.3.1

[tests]
pytest
