import jax.numpy as np
from jax import grad, jit, vmap
import jax
# Import-time diagnostics: print the count returned by jax.local_device_count()
# and the first entry of jax.devices().
# NOTE(review): these run on every import of this module; consider guarding
# them under `if __name__ == "__main__":` if the module is ever imported.
print(jax.local_device_count())
print(jax.devices()[0])
from functools import partial


def predict(params, inputs):
    """Run a feed-forward network with tanh activations between layers.

    Each element of ``params`` is a ``(W, b)`` pair applied as the affine map
    ``np.dot(inputs, W) + b``. ``tanh`` is applied to every layer's output
    before it is fed to the next layer, but the last layer's affine output is
    returned without an activation.

    Args:
        params: Iterable of ``(W, b)`` weight/bias pairs, one per layer.
            NOTE(review): presumably W is (in_dim, out_dim) and b is (out_dim,)
            so consecutive layers chain; shapes are not validated here.
        inputs: Input array fed to the first layer.

    Returns:
        The final layer's pre-activation output. When ``params`` is empty,
        ``inputs`` is returned unchanged (an identity network) instead of
        failing with ``UnboundLocalError``.
    """
    # Seed with the inputs so that zero layers yields the inputs themselves
    # rather than a reference to an unbound ``outputs`` variable.
    outputs = inputs
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        # The activation only feeds the next layer; the last layer stays linear.
        inputs = np.tanh(outputs)
    return outputs


def logprob_fun(params, inputs, targets):
    """Return the summed squared error between network predictions and targets.

    Args:
        params: Layer parameters forwarded unchanged to ``predict``.
        inputs: Inputs forwarded to ``predict``.
        targets: Values the predictions are compared against elementwise.

    Returns:
        ``sum((predict(params, inputs) - targets) ** 2)`` as a scalar array.

    NOTE(review): despite the name, the value is a non-negative squared-error
    loss rather than a log-probability (it has the opposite sign of a Gaussian
    log-likelihood). Callers differentiate it as-is, so the name and sign are
    kept.
    """
    residuals = predict(params, inputs) - targets
    squared_residuals = residuals ** 2
    return np.sum(squared_residuals)


# Gradient of logprob_fun with respect to its first argument (params) for a
# single (inputs, targets) pair, compiled with jit.
grad_fun = jit(grad(logprob_fun, argnums=0))

# Per-example gradients: vmap maps grad_fun over axis 0 of inputs and targets
# while passing one shared, unmapped params (axis None); the outer jit
# compiles the whole batched computation.
perex_grads = jit(
    vmap(
        grad_fun,
        in_axes=(None, 0, 0),
    )
)

print(perex_grads)