jax>=0.2.9
mpi4py>=3.0.1
numpy

[dev]
pytest>=6
pytest-cov>=2.10.1
coverage[toml]>=5
pre-commit
black==21.6b0
flake8==3.9.2
