jax>=0.2.9
mpi4py>=3.0.1
numpy
