jax>=0.2.13
jaxlib>=0.1.65
multipledispatch
numpy
tqdm

[cpu]
jax[cpu]>=0.2.13

[cuda]
jax[cuda]>=0.2.13

[dev]
dm-haiku
flax
funsor>=0.4.1
graphviz
jaxns==0.0.7
optax>=0.0.6
tensorflow_probability>=0.15.0

[doc]
ipython
nbsphinx>=0.8.5
readthedocs-sphinx-search==0.1.0
sphinx
sphinx_rtd_theme
sphinx-gallery

[examples]
arviz
jupyter
matplotlib
pandas
seaborn
scikit-learn
wordcloud

[test]
black[jupyter]>=21.8b0
flake8
isort>=5.0
pytest>=4.1
pyro-api>=0.1.1
scipy<1.7,>=1.6

[tpu]
jax[tpu]>=0.2.13
