jax>=0.3.0
jaxlib>=0.3.0
chex
flax
numpy
pyyaml

[:python_version < "3.8"]
pickle5
