NumPy>=1.17.0
SciPy>=1.6.3
pandas>=1.2.4
torch>=1.7
jax>=0.4.2
flax>=0.6.4
optax>=0.1.4
tqdm>=4.0.0
