# Requirements listed before any [section] header. Each is a lower bound
# only; no upper bounds are declared here.
# NOTE(review): the layout matches a generated setuptools egg-info
# requires.txt, in which case these are presumably the unconditional install
# requirements and manual edits would be overwritten on rebuild — confirm.
numpy>=1.16.0
scipy>=1.2.1
autograd>=1.3

# "jax" section: three lower-bound requirements. Each minimum equals the
# exact version pinned for the same package in the [tests] section.
[jax]
dm-haiku>=0.0.5
jax>=0.2.28
jaxlib>=0.1.74

# "pytorch" section: torch with a lower bound only. The [tests] section pins
# torch==1.13.1, which satisfies this bound.
[pytorch]
torch>=1.8.0

# "tests" section: every entry except pytest-xdist is pinned to an exact
# version. Each pin satisfies the lower bound declared for the same package
# elsewhere in this file (numpy, scipy, autograd, dm-haiku, jax, jaxlib,
# tensorflow, torch).
# absl-py and dm-tree are declared only in this section.
# NOTE(review): presumably they are pinned because other packages in this
# list depend on them — confirm before changing or dropping them.
[tests]
pytest-xdist
numpy==1.21.5
scipy==1.7.3
autograd==1.3
absl-py==0.15.0
dm-tree==0.1.6
dm-haiku==0.0.5
jax==0.2.28
jaxlib==0.1.74
tensorflow==2.7.0
torch==1.13.1

# "tf" section: tensorflow with a lower bound only. The [tests] section pins
# tensorflow==2.7.0, which satisfies this bound.
[tf]
tensorflow>=1.15
