# Python dependencies in pip requirements-file format
# (install with `pip install -r <this file>`).
#
# Every entry below is a lower bound only (`>=`); no upper bounds or exact pins
# are given, so the versions actually installed depend on what the resolver
# picks at install time.
# NOTE(review): if reproducible environments matter, consider a separate
# lock/constraints file — confirm with the project's maintainers.

# jax and jaxlib are listed with the same minimum version.
# NOTE(review): jax and jaxlib releases are version-coupled upstream; verify
# that the resolved pair is a supported combination — TODO confirm.
jax>=0.4.0
jaxlib>=0.4.0
optax>=0.1.3
tensorflow>=2.9.0
# pip normalizes `_` and `-` in project names, so this resolves to the
# `tensorflow-datasets` distribution.
tensorflow_datasets>=4.3.0
# NOTE(review): the packages above declare their own numpy requirements, which
# may be stricter than this floor; the effective minimum is whichever is higher.
numpy>=1.19.0
