Is it possible to regularize gradients without multiple graph traversals?

I am currently trying to regularize the training of a network using the norm of the gradient of one of its outputs with respect to one of its inputs.

The issue is that, because the loss depends on these gradients, I first have to call autograd.grad, which triggers one graph traversal, and then call loss.backward(), which traverses the graph again, to update the weights of my network.
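A small sketch of why the second traversal cannot simply be dropped, assuming a hypothetical two-layer model (any differentiable module behaves the same way). Without create_graph=True, torch.autograd.grad returns a plain tensor with no graph attached, so a penalty built from it never reaches the weights. With create_graph=True, autograd records the backward pass itself, and loss.backward() then has to differentiate through it (a "double backward").

```python
import torch
from torch.autograd import grad

torch.manual_seed(0)
# Hypothetical small model, used only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(10, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
inputs = torch.rand(4, 10, requires_grad=True)
pred = model(inputs)

# Without create_graph: the result is detached from the weights,
# so a penalty computed from it contributes nothing to the weight gradients.
g_plain = grad(pred.sum(), inputs, retain_graph=True)[0]

# With create_graph=True: the backward pass is recorded as a graph,
# so the penalty can be differentiated w.r.t. the weights.
g_graph = grad(pred.sum(), inputs, create_graph=True)[0]

# Penalty on the first 5 input features, as in the question.
penalty = g_graph[:, :5].norm(dim=1).sum()
penalty.backward()  # second traversal: double backward through g_graph
first_weight_grad = model[0].weight.grad

print(g_plain.requires_grad, g_graph.requires_grad)
```

Both calls compute the same numbers; only the second one keeps the graph needed to train on the penalty, which is where the extra backward cost comes from.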

This roughly doubles the training time, because a lot of time is spent in the backward passes. Is there a better way to do it?

Hi @Soudini,

You could create a custom optimiser and reuse those gradient terms within the optimiser itself. But do share a minimal reproducible example that shows the problem!

Hi, thank you for the response.

What I want to do is something like this:

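```python
import torch
from torch.autograd import grad
# Hypothetical stand-ins so the snippet runs; any model/target/optimizer works.
model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))
optim = torch.optim.SGD(model.parameters(), lr=1e-2)
inputs = torch.rand(10, 10, requires_grad=True)  # inputs must require grad
target = torch.rand(10, 1)
pred = model(inputs)
# Here is my first graph traversal: keep only the gradients w.r.t. the first 5 inputs.
# create_graph=True (not just retain_graph=True) makes the penalty trainable;
# the result is named input_grad so it does not shadow the grad function.
input_grad = grad(pred.sum(), inputs, create_graph=True)[0][:, :5]
# Here is my second graph traversal (double backward through the first one)
loss = input_grad.norm(dim=1).sum() + (pred - target).norm()
optim.zero_grad()
loss.backward()
optim.step()
```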
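For reference, the loss in the snippet above can be written as follows, where f_θ is the model with weights θ, x the input batch, y the target, and x_{i,1:5} the first five features of sample i (this notation is introduced here for illustration):

$$
g_i(\theta) = \nabla_{x_{i,1:5}} \sum_{j,k} f_\theta(x)_{j,k}, \qquad
\mathcal{L}(\theta) = \sum_i \lVert g_i(\theta) \rVert_2 + \lVert f_\theta(x) - y \rVert_F
$$

Differentiating the penalty with respect to θ brings in mixed second derivatives of f_θ, which is why loss.backward() has to traverse the graph built by autograd.grad a second time:

$$
\nabla_\theta \lVert g_i \rVert_2 = \frac{1}{\lVert g_i \rVert_2} \left( \frac{\partial g_i}{\partial \theta} \right)^{\!\top} g_i, \qquad
\frac{\partial g_i}{\partial \theta} = \frac{\partial^2}{\partial \theta \, \partial x_{i,1:5}} \sum_{j,k} f_\theta(x)_{j,k}
$$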

Currently, this takes around twice as long as training without the gradient regularization.
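To check the overhead on a particular setup, one could time a training step with and without the penalty. This is a minimal CPU sketch: the model, batch size and number of steps are hypothetical, and on a GPU you would need torch.cuda.synchronize() before reading the clock. The measured ratio depends on the model and hardware, so treat it as an estimate rather than a fixed factor.

```python
import time

import torch
from torch.autograd import grad

torch.manual_seed(0)
# Hypothetical model and data sizes, chosen only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(10, 256), torch.nn.Tanh(), torch.nn.Linear(256, 1))
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
inputs = torch.rand(512, 10)
target = torch.rand(512, 1)


def train_step(regularize: bool) -> float:
    # Inputs only need to require grad when the gradient penalty is used.
    x = inputs.clone().requires_grad_(regularize)
    pred = model(x)
    loss = (pred - target).norm()
    if regularize:
        # Extra traversal: gradient of the outputs w.r.t. the first 5 inputs.
        g = grad(pred.sum(), x, create_graph=True)[0][:, :5]
        loss = loss + g.norm(dim=1).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item()


def time_steps(regularize: bool, n: int = 20) -> float:
    train_step(regularize)  # warm-up step, not timed
    start = time.perf_counter()
    for _ in range(n):
        train_step(regularize)
    return (time.perf_counter() - start) / n  # seconds per step


t_plain = time_steps(False)
t_reg = time_steps(True)
print(f"plain: {t_plain * 1e3:.2f} ms/step, regularized: {t_reg * 1e3:.2f} ms/step, "
      f"ratio: {t_reg / t_plain:.2f}")
```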

I will also try your suggestion and get back to you.