# Python package dependencies, one requirement per line.
# Every entry uses a `>=` specifier: it sets a minimum version only and places
# no upper bound, so newer releases (including new major versions) are accepted.
# NOTE(review): the per-package notes below describe what each upstream package
# is; how this project uses them is not visible from this file — confirm.

# JAX: array computing / automatic differentiation library.
# NOTE(review): no explicit `jaxlib` entry here — confirm it is pulled in for
# the target install (CPU/GPU/TPU builds are selected outside this file).
jax>=0.3.4
# einops: tensor rearrange/reduce/repeat helpers.
einops>=0.4.1
# Flax: neural-network library built on JAX.
flax>=0.5.0
# Optax: gradient-processing and optimizer library for JAX.
optax>=0.1.2
# NumPy: N-dimensional array library.
numpy>=1.21.2
# datasets: presumably the Hugging Face dataset-loading package — verify.
datasets>=2.2.2
