jax>=0.3.0
jaxlib>=0.3.0
chex
flax
numpy
