graphviz>=0.14.1
numpy
IPython
matplotlib

[all]
tensorflow
torch
jax
jaxlib

[jax]
jax
jaxlib

[tensorflow]
tensorflow

[torch]
torch
