Applying an input-dependent transformation to the gradients during the backward pass

Hi everyone,

I have a specific use case for autograd and would like your opinion on an efficient way to make it work.

For context, I want to apply a transformation to the gradients before executing the optimization step. This transformation multiplies the gradient of the loss element-wise by the gradient of the prediction function.

For a parameter p, a prediction function (neural network) f, and a loss L, the transformation is the following:

$$\mathrm{grad}_p = \frac{\partial f}{\partial p} \odot \frac{\partial L}{\partial p}$$

The problem is that both gradients are input-dependent, which means that the transformation must be applied to each sample's gradients before they are accumulated (summed) over the batch during the backward pass.
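A small worked example of why the order matters, using a hypothetical toy setting (not from the question): a scalar parameter p, a model f(x) = p * x, and a squared-error loss, evaluated at p = 1 on two samples. Transforming per sample and then summing gives a different result than summing first and then transforming.

```python
# Hypothetical toy setting: scalar parameter p, model f(x) = p * x,
# squared-error loss L = (f(x) - y) ** 2, evaluated at p = 1.
p = 1.0
batch = [(1.0, 0.0), (2.0, 0.0)]  # (x, y) pairs

df_dp = [x for x, _ in batch]                     # d f / d p = x
dL_dp = [2 * (p * x - y) * x for x, y in batch]  # d L / d p = 2 (p x - y) x

# transform each sample, then accumulate over the batch (the desired quantity)
per_sample = sum(a * b for a, b in zip(df_dp, dL_dp))
# accumulate over the batch first, then transform
after_accumulation = sum(df_dp) * sum(dL_dp)

print(per_sample, after_accumulation)  # 18.0 30.0
```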

The naive solution I can think of is to use a batch size of 1, in which case I can proceed as follows:

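# batch input X (batch_size = 1), batch groundtruth y
prediction = model(X)
prediction.backward(retain_graph=True)  # d(f)/d(p), accumulated into param.grad
gradients = []
for param in model.parameters():
    gradients.append(param.grad.clone())  # store d(f)/d(p)
model.zero_grad()
loss = loss_fn(prediction, y)  # MSE
loss.backward()  # d(L)/d(p); reuses the retained graph
for i, param in enumerate(model.parameters()):
    param.grad *= gradients[i]  # element-wise transformation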
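The batch-size-1 procedure above can be generalized to a full batch by looping over the samples and using `torch.autograd.grad`, so that each sample's two gradients are multiplied before they are summed. This is a minimal sketch, assuming the model returns one scalar per sample (shape `[batch, 1]`); the function name `transformed_grads_loop` is hypothetical. It is still one pair of backward passes per sample, so it is no faster than the naive version, but it produces the target quantity for a whole batch and can serve as a reference.

```python
import torch
import torch.nn as nn


def transformed_grads_loop(model, loss_fn, X, y):
    """Sum over samples of df(x_i)/dp * dL(x_i)/dp, element-wise per parameter.

    Assumes the model outputs one scalar per sample (shape [batch, 1]).
    """
    params = [p for p in model.parameters() if p.requires_grad]
    totals = [torch.zeros_like(p) for p in params]
    for i in range(X.shape[0]):
        x_i, y_i = X[i : i + 1], y[i : i + 1]  # keep a batch dim of 1
        pred = model(x_i)
        # df(x_i)/dp; retain the graph so the loss can reuse it
        df = torch.autograd.grad(pred.sum(), params, retain_graph=True)
        # dL(x_i)/dp
        dL = torch.autograd.grad(loss_fn(pred, y_i), params)
        for t, a, b in zip(totals, df, dL):
            t += a * b  # transform this sample, then accumulate
    # store the result where an optimizer expects it
    for p, t in zip(params, totals):
        p.grad = t
    return totals
```

After calling it, `optimizer.step()` can be used as usual, since the transformed gradients are written to `param.grad`.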

This is very compute-inefficient. Is there a way to apply the transformation per sample, on the fly, during a batched backward pass?

I have thought about registering a backward hook, but it is not clear to me how to pass the first gradient (d(f)/d(p)) to the hook.
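One way to batch the computation without hooks, swapped in here as an alternative technique, is to compute per-sample gradients with `torch.func` (`vmap` over `grad`, with `functional_call`), multiply them per sample, and then sum over the batch. This is a sketch under assumptions: PyTorch >= 2.0, one scalar output per sample, and the MSE loss written out explicitly; `transformed_grads_vmap` is a hypothetical name. Memory grows with the batch size, because every sample's gradient is materialized.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap


def transformed_grads_vmap(model, X, y):
    """Sum over the batch of df(x)/dp * dL(x)/dp, element-wise per parameter.

    Assumes PyTorch >= 2.0, one scalar output per sample, and an MSE loss.
    """
    params = {k: v.detach() for k, v in model.named_parameters()}

    def f_single(p, x):
        # model output for one sample: add a batch dim of 1, return a scalar
        return functional_call(model, p, (x.unsqueeze(0),)).squeeze()

    def loss_single(p, x, t):
        # squared error for one sample (MSE, as in the question)
        return ((f_single(p, x) - t.squeeze()) ** 2).mean()

    # per-sample gradients: dicts whose tensors gain a leading batch dimension
    df = vmap(grad(f_single), in_dims=(None, 0))(params, X)
    dL = vmap(grad(loss_single), in_dims=(None, 0, 0))(params, X, y)

    # transform each sample's gradients, then accumulate over the batch
    out = {k: (df[k] * dL[k]).sum(dim=0) for k in params}
    for k, p in model.named_parameters():
        p.grad = out[k]  # ready for optimizer.step()
    return out
```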

Thank you.

That is an interesting question, and thanks for the clear code example.

Using a backward hook as you suggest is how I would approach it. I’ve not tested it, but have you tried something like adding:

loss.register_hook(lambda param_grad: param_grad * gradients.pop(0))

Right before loss.backward()? The idea is to consume the list of gradients from the first backward pass in the same order (this relies on both backward passes occurring in the same order).
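A variant of the hook idea, sketched here as an assumption rather than the approach above: a hook registered on `loss` itself receives only the gradient of the loss with respect to itself, so this sketch instead registers one `Tensor.register_hook` per parameter, each bound to that parameter's d(f)/d(p), which removes the dependence on visiting order. It assumes a batch size of 1; with a larger batch the hook would still see gradients that are already summed over the batch. `hooked_backward` is a hypothetical name.

```python
import torch
import torch.nn as nn


def hooked_backward(model, loss_fn, x, y):
    """Batch size 1: scale each parameter's dL/dp by df/dp via tensor hooks."""
    params = list(model.parameters())
    pred = model(x)
    # first pass: df/dp for every parameter, without touching .grad
    df = torch.autograd.grad(pred.sum(), params, retain_graph=True)
    # one hook per parameter, bound to its own df (d=d avoids late binding)
    handles = [p.register_hook(lambda g, d=d: g * d) for p, d in zip(params, df)]
    model.zero_grad(set_to_none=True)
    # second pass: each hook multiplies dL/dp by df/dp before it lands in .grad
    loss_fn(pred, y).backward()
    for h in handles:
        h.remove()  # later backward passes are unaffected
    return [p.grad for p in params]
```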

Hi yiftach,

Thank you for your answer.

In the case of a batch size > 1, if we first compute the gradient of the prediction with prediction.backward(retain_graph=True), the gradients will already be accumulated across the batch.

Then, if we apply the hook to the loss as you suggested, the multiplication (the hook call) will also happen after the accumulation across the batch.

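So for a batch $B$ of samples $(x, y)$, parameter $\theta$, model $f_\theta$, and loss function $L$, what we would compute at the end is the first equation, while the expected behavior is the second:

$$\Big(\sum_{(x,y)\in B} \frac{\partial f_\theta(x)}{\partial \theta}\Big) \odot \Big(\sum_{(x,y)\in B} \frac{\partial L(f_\theta(x), y)}{\partial \theta}\Big)$$

$$\sum_{(x,y)\in B} \frac{\partial f_\theta(x)}{\partial \theta} \odot \frac{\partial L(f_\theta(x), y)}{\partial \theta}$$
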
Am I right, or is this not the usage of the hook that you suggested?