# How to Penalize Norm of End-to-End Jacobian

From my understanding, JAX does as many backward passes as there are outputs, just like @richard proposed. But they have the `vmap` operator, which allows them to do this more efficiently than a for loop in Python (even though the theoretical complexity is the same).

Do you know how their `vmap` works? I'm curious whether it's similar to what NestedTensor will eventually be, or if they just do some program transformations to accomplish it.

From my understanding, they implement a batched version of every primitive: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#Batching
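As a minimal sketch of the idea (assuming JAX is installed; the function `f` below is just an illustrative example, not from the thread), `vmap` turns a per-example gradient function into a batched one without a Python loop:

```python
# Sketch: jax.vmap batches a per-example gradient computation.
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function of a single input vector.
    return jnp.sum(jnp.tanh(x) ** 2)

# Gradient of f for ONE input vector.
g = jax.grad(f)

# Batched gradients: one vmap call instead of a Python for-loop.
batched_g = jax.vmap(g)

xs = jnp.ones((4, 3))   # batch of 4 input vectors
grads = batched_g(xs)
print(grads.shape)      # (4, 3): one gradient per example
```

The same mechanism is what lets JAX compute full Jacobians by batching over the output dimensions rather than looping in Python.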

I guess this post is related

And you might like the following two repos

Unfortunately, it doesn’t support all types of networks (e.g. no batch norm), but convnets and all types of activations will work with just one backprop.

@albanD @richard how does the linked gist work? Previously, when I tried running `backward()` twice, the model wouldn't learn, but switching to `autograd.grad()` fixed whatever problem existed. If I can use the linked gist and get `backward()` to work when run twice, then I might have a solution!
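For reference, a minimal sketch of the `autograd.grad()` approach that worked (the tensors and loss here are made up for illustration; the key detail is `create_graph=True`, which makes the gradient itself differentiable):

```python
# Sketch: gradient penalty via autograd.grad, which does NOT write into
# the .grad fields of the parameters.
import torch

x = torch.randn(5, 3)
w = torch.randn(3, 1, requires_grad=True)

out = (x @ w).pow(2).sum()

# create_graph=True keeps the graph so the gradient is differentiable.
(g,) = torch.autograd.grad(out, w, create_graph=True)
penalty = g.pow(2).sum()

# A single .backward() on the combined objective writes .grad exactly once.
(out + penalty).backward()
print(w.grad.shape)  # torch.Size([3, 1])
```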

Hi,

The difference is that `.backward()` accumulates gradients into the `.grad` fields of the leaves; `.grad()` does not.

So you were most likely accumulating extra gradients when making multiple calls to `.backward()` if you were not calling `.zero_grad()` before the last one.
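A small sketch of this difference (toy tensors, not from the thread): `autograd.grad()` leaves `.grad` untouched, while each `.backward()` call adds into it:

```python
# Sketch: .backward() accumulates into .grad; autograd.grad() does not.
import torch

w = torch.ones(2, requires_grad=True)

loss = (w ** 2).sum()
(g,) = torch.autograd.grad(loss, w)
print(w.grad)   # None: autograd.grad did not touch .grad

loss2 = (w ** 2).sum()
loss2.backward()
print(w.grad)   # tensor([2., 2.])

loss3 = (w ** 2).sum()
loss3.backward()
print(w.grad)   # tensor([4., 4.]) -- accumulated, hence zero_grad()
```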

Hmm… I’m not sure this accords with what I see empirically. I’m using `grad()` followed by `loss.backward()`, and this seems to change the training of the model compared with just running `loss.backward()`. If `grad()` doesn’t accumulate gradients, then why does the outcome differ? Maybe I don’t know what exactly you mean by leaves.

Sorry, maybe that wasn’t clear. The two different cases are:

```
opt.zero_grad()
loss = xxx(inputs) # Compute your loss
grads = xxx(loss) # Compute gradients wrt your loss
penalty = xxx(grads) # Compute the gradient penalty
final_loss = loss + penalty
final_loss.backward()
opt.step()
```

In the example above, you want to make your gradient step **only** with the gradients computed during `final_loss.backward()`. But if the computation of `grads` is done with `.backward(create_graph=True)`, then you accumulate some extra gradients. You don’t do this if you compute `grads` with `autograd.grad(create_graph=True)`.

So the two gradients at the moment you step are different. That could explain your model training properly in one case but not in the other.
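A hedged concrete version of the pseudocode above (the model, optimizer, and loss are illustrative placeholders for the `xxx` calls; the essential points are `autograd.grad(create_graph=True)` and a single `.backward()` on the combined loss):

```python
# Sketch of the correct pattern: penalize the gradient wrt the inputs,
# then make exactly one .backward() call on the combined loss.
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(8, 4, requires_grad=True)
targets = torch.randn(8, 1)

opt.zero_grad()
loss = torch.nn.functional.mse_loss(model(inputs), targets)

# Gradient of the loss wrt the inputs, kept differentiable so the
# penalty can itself be backpropagated.
(grads,) = torch.autograd.grad(loss, inputs, create_graph=True)
penalty = grads.pow(2).sum()

final_loss = loss + penalty
final_loss.backward()   # the ONLY call that writes into .grad
opt.step()
print(final_loss.item())
```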