jax>=0.2.21
numpyro>=0.8.0
dm-haiku>=0.0.5
matplotlib>=3.1
