jax>=0.4
jaxlib>=0.4
multipledispatch
numpy
tqdm

[cpu]
jax[cpu]>=0.4

[cuda]
jax[cuda]>=0.4

[dev]
dm-haiku
flax
funsor>=0.4.1
graphviz
jaxns==1.0.0
optax>=0.0.6
pyyaml
tensorflow_probability>=0.17.0

[doc]
ipython
nbsphinx>=0.8.5
readthedocs-sphinx-search==0.1.0
sphinx
sphinx_rtd_theme
sphinx-gallery

[examples]
arviz
jupyter
matplotlib
pandas
seaborn
scikit-learn
wordcloud

[test]
importlib-metadata<5.0
black[jupyter]>=21.8b0
flake8
isort>=5.0
pytest>=4.1
pyro-api>=0.1.1
scipy<1.7,>=1.6

[tpu]
jax[tpu]>=0.4
