# Python package requirements (pip requirements-file format).
# Exact pin: under PEP 440, "==0.4" matches 0.4.0 (zero-padded) but not 0.4.1.
einops==0.4
equinox>=0.5
jax>=0.3.4
# jax checks at import time that the installed jaxlib meets its minimum
# version; for the jax 0.3 series that minimum is jaxlib 0.3.0 (per the JAX
# changelog), so a looser floor like 0.1 could leave an incompatible jaxlib.
jaxlib>=0.3.0
optax
numpy
