jax
jaxlib>=0.1.55
mpi4py>=3.0.1
numpy

[dev]
pytest
