jax
jaxlib>=0.1.55
mpi4py>=3.0.1
numpy

[dev]
pytest
black
flake8==3.8.3
pre-commit>=2
