jax>=0.2.21
jmp>=0.0.2

[test]
chex
dm-haiku
fire
opax
pytest
pytype
tqdm
