# Python package requirements, one specifier per line (pip requirements-file syntax).
# Entries without a version specifier accept any available release.
# Unpinned entries:
numpy
scipy
matplotlib
seaborn
# jaxlib and jax share the same minimum version (0.4.1); no upper bound is set.
# NOTE(review): presumably the two are meant to be kept in step — confirm when bumping either.
jaxlib>=0.4.1
jax>=0.4.1
# Unpinned, unlike the jax/jaxlib entries above.
# NOTE(review): presumably any release compatible with jax>=0.4.1 is acceptable — confirm.
optax
flax
# Minimum versions only; no upper bound is set.
# NOTE(review): `backends` looks like the PyPI distribution of the LAB library — verify before renaming or replacing.
backends>=1.4.32
mlkernels>=0.3.6
