jax>=0.4.6
jaxlib>=0.4.6
tensorflow_probability
chex
typing_extensions
equinox>=0.10.3
optax
wandb
