jax
dm-haiku
optax
numpy
jax-unirep
