jax>=0.2.13
jaxlib>=0.1.65
tqdm

[cpu]
jax[cpu]>=0.2.13

[cuda101]
jax[cuda101]>=0.2.13

[cuda102]
jax[cuda102]>=0.2.13

[cuda110]
jax[cuda110]>=0.2.13

[cuda111]
jax[cuda111]>=0.2.13

[dev]
dm-haiku
flax
funsor==0.4.1
graphviz
jaxns==0.0.7
optax>=0.0.6
tensorflow_probability>=0.13

[doc]
ipython
nbsphinx>=0.8.5
sphinx
sphinx_rtd_theme
sphinx-gallery

[examples]
arviz
jupyter
matplotlib
pandas
seaborn
scikit-learn
wordcloud

[test]
black
flake8
isort>=5.0
pytest>=4.1
pyro-api>=0.1.1
scipy>=1.1

[tpu]
jax[tpu]>=0.2.13
