dm_haiku==0.0.9
jax==0.4.8
jaxlib==0.4.8
jraph==0.0.6.dev0
numpy>=1.23.4
optax==0.1.3
