How to Penalize Norm of End-to-End Jacobian

From my understanding, JAX also does one backward pass per output, just like @richard proposed. But it has the vmap operator, which lets it do this more efficiently than a for loop in Python (even though the theoretical complexity is the same).
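To make the loop-versus-vmap comparison concrete, here is a minimal PyTorch sketch. It assumes PyTorch 2.0 or newer, where the torch.func module provides JAX-style vmap and vjp. The map f and its weights W are hypothetical stand-ins for a network. Both options build the Jacobian row by row from vector-Jacobian products (one per output); the second one only batches them instead of looping in Python.

```python
import torch
from torch.func import vjp, vmap

W = torch.tensor([[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]])  # hypothetical weights

def f(x):
    # Toy end-to-end map from R^3 to R^2, standing in for a network
    return torch.tanh(W @ x)

torch.manual_seed(0)
x = torch.randn(3)
out_dim = 2

# Option 1: one backward pass per output, driven by a Python loop
x_req = x.clone().requires_grad_(True)
y = f(x_req)
rows = []
for i in range(out_dim):
    # retain_graph=True so the same graph can be reused for the next row
    (row,) = torch.autograd.grad(y[i], x_req, retain_graph=True)
    rows.append(row)
jac_loop = torch.stack(rows)  # shape (out_dim, 3)

# Option 2: the same vector-Jacobian products, batched with vmap
_, vjp_fn = vjp(f, x)
cotangents = torch.eye(out_dim)  # one one-hot cotangent per output row
(jac_vmap,) = vmap(vjp_fn)(cotangents)  # shape (out_dim, 3)

print(torch.allclose(jac_loop, jac_vmap, atol=1e-6))  # True
```

The amount of work is the same in both cases (one vector-Jacobian product per output); vmap just removes the Python-level loop.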

Do you know how their vmap works? I'm curious whether it's similar to what NestedTensor will eventually be, or whether they just use program transformations to accomplish it.

From my understanding, they implement a batched version of every primitive: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#Batching
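The idea on that JAX page can be illustrated with a single operation: the batched version of a dot product, applied to a batch of vectors against one shared vector, is a single matrix-vector product. The sketch below assumes PyTorch 2.0+ and uses torch.func.vmap as a stand-in for JAX's vmap; dot and dot_batching_rule are hypothetical names, and this is not a description of either library's internals.

```python
import torch
from torch.func import vmap

def dot(x, w):
    # Unbatched operation: two vectors of shape (3,) -> scalar
    return (x * w).sum()

def dot_batching_rule(xs, w):
    # Hand-written batched version for a batch dim on xs only:
    # a stack of dot products with the same w is one matrix-vector product
    return xs @ w

torch.manual_seed(0)
xs = torch.randn(5, 3)  # batch of 5 vectors
w = torch.randn(3)      # shared, unbatched

auto = vmap(dot, in_dims=(0, None))(xs, w)  # vmap derives the batched version
manual = dot_batching_rule(xs, w)
print(auto.shape, torch.allclose(auto, manual, atol=1e-6))  # torch.Size([5]) True
```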

I guess this post is related


And you might like the following two repos


Unfortunately, it doesn’t support all types of networks (e.g. no batch norm), but convnets and all types of activations will work with just one backprop.
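The repos referred to here were not preserved in this thread, so their exact method is unknown. As a separately named alternative, one common way to penalize the end-to-end Jacobian with a single backward pass is a random-projection estimate of its squared Frobenius norm: for a standard Gaussian vector v in output space, the expectation of ||v^T J||^2 equals ||J||_F^2. The helper jacobian_penalty below is a hypothetical sketch of that estimator, not the linked repos' code. It assumes each output row depends only on its own input row; batch norm in training mode mixes examples and would break that assumption.

```python
import torch

def jacobian_penalty(model, x, n_proj=1):
    """Hypothetical helper: per-example random-projection estimate of
    ||d model(x) / d x||_F^2, using one backward pass per projection.

    Assumes each output row depends only on its own input row (no
    cross-example ops such as batch norm in training mode), so summing
    over the batch before differentiating keeps examples separate.
    """
    x = x.detach().requires_grad_(True)
    y = model(x)  # shape (B, out_dim)
    est = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
    for _ in range(n_proj):
        v = torch.randn_like(y)  # Gaussian direction in output space
        # Row b of vjp_rows is v[b]^T J_b; create_graph=True keeps it
        # differentiable so the penalty can be minimized by the optimizer
        (vjp_rows,) = torch.autograd.grad((y * v).sum(), x, create_graph=True)
        est = est + vjp_rows.pow(2).flatten(1).sum(dim=1)
    return est / n_proj

# Usage with a toy model and a placeholder task loss (illustrative only)
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 3))
x = torch.randn(16, 4)
task_loss = model(x).pow(2).mean()
loss = task_loss + 0.01 * jacobian_penalty(model, x).mean()
loss.backward()
```

With n_proj=1 this is one extra backward pass regardless of the output dimension; more projections lower the variance of the estimate at the cost of more passes.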

@albanD @richard how does the linked gist work? Previously, when I tried running backward() twice, the model wouldn’t learn, but switching to autograd.grad() fixed whatever problem existed. If I can use the linked gist and get backward() to work when run twice, then I might have a solution!

Hi,

The difference is that .backward() accumulates gradients into the .grad fields of the leaves, while autograd.grad() does not: it returns the gradients instead.
So you were most likely accumulating extra gradients when making multiple calls to .backward(), if you were not calling .zero_grad() before the last one.
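A minimal sketch of that accumulation behaviour, using a single hypothetical scalar parameter w:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)  # a leaf tensor, like a model parameter
loss = w ** 2                               # d(loss)/dw = 2 * w = 6

loss.backward(retain_graph=True)
print(w.grad)  # tensor(6.)
loss.backward(retain_graph=True)
print(w.grad)  # tensor(12.): the second call added to .grad instead of overwriting it

w.grad = None  # same effect as opt.zero_grad(set_to_none=True) on this parameter
(g,) = torch.autograd.grad(loss, w)
print(g, w.grad)  # tensor(6.) None: the gradient is returned, .grad is untouched
```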

Hmm… I’m not sure this accords with what I see empirically. I’m using grad() followed by loss.backward() and this seems to change the training of the model compared with just running loss.backward(). If grad doesn’t accumulate gradients, then why does the outcome differ?

Maybe I don't know exactly what you mean by leaves.
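Leaves here means leaf tensors of the autograd graph: tensors created directly by the user, such as model parameters, rather than produced by an operation on tensors that require grad. By default, .backward() only populates .grad on leaves that require grad. A small sketch with a toy nn.Linear (illustrative only):

```python
import torch

model = torch.nn.Linear(2, 1)
x = torch.randn(4, 2)
out = model(x)  # produced by an operation on the parameters
loss = out.sum()

# Parameters are created directly (not by an op), so they are leaves;
# tensors computed from them have a grad_fn and are not leaves.
print(model.weight.is_leaf, out.is_leaf)              # True False
print(model.weight.grad_fn, out.grad_fn is not None)  # None True

loss.backward()
# By default, .backward() fills .grad only on leaves that require grad
print(model.weight.grad is not None)  # True
```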

Sorry, that may not have been clear. The two different cases are:

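import torch
model = torch.nn.Linear(3, 1)  # toy model, for illustration
opt = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(8, 3, requires_grad=True)  # requires_grad so we can differentiate wrt the inputs
targets = torch.randn(8, 1)

opt.zero_grad()
loss = torch.nn.functional.mse_loss(model(inputs), targets)  # Compute your loss
(grads,) = torch.autograd.grad(loss, inputs, create_graph=True)  # Gradients of the loss wrt the inputs, kept differentiable
penalty = grads.pow(2).sum()  # Compute the gradient penalty (squared norm)
final_loss = loss + penalty
final_loss.backward()
opt.step()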

In the example above, you want the optimizer step to use only the gradients computed during final_loss.backward(). But if you compute grads with loss.backward(create_graph=True) instead, that call also accumulates the gradients of loss into the parameters' .grad fields, and final_loss.backward() then adds on top of them. You don't get this with autograd.grad(create_graph=True), which returns the gradients without touching .grad.
So the gradients used when you step are different in the two cases. That could explain why your model trains properly in one case but not the other.
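A minimal sketch comparing the two cases, assuming a toy nn.Linear model, MSE loss, and the squared norm of the input gradient as the penalty (all illustrative choices). param_grads_at_step is a hypothetical helper that returns the parameter gradients opt.step() would see; in the .backward(create_graph=True) variant they contain one extra copy of d(loss)/d(params).

```python
import torch
import torch.nn.functional as F

def param_grads_at_step(use_backward):
    torch.manual_seed(0)  # identical model and data in both variants
    model = torch.nn.Linear(3, 1)
    inputs = torch.randn(8, 3, requires_grad=True)
    targets = torch.randn(8, 1)
    model.zero_grad()
    loss = F.mse_loss(model(inputs), targets)
    if use_backward:
        # Side effect: also writes d(loss)/d(params) into each p.grad
        loss.backward(create_graph=True)
        grads = inputs.grad
    else:
        # Returns d(loss)/d(inputs) and leaves every .grad untouched
        (grads,) = torch.autograd.grad(loss, inputs, create_graph=True)
    penalty = grads.pow(2).sum()
    (loss + penalty).backward()
    # These are the gradients opt.step() would use
    return [p.grad.detach().clone() for p in model.parameters()]

for a, b in zip(param_grads_at_step(False), param_grads_at_step(True)):
    print((b - a).abs().max().item())  # nonzero: the extra copy of d(loss)/d(params)
```

Recent PyTorch versions may also emit a UserWarning in the .backward(create_graph=True) variant about a reference cycle between parameters and their gradients, recommending autograd.grad instead.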