How to modify Conv2d input gradients using backward hook?

I want to change the gradients during the backward pass for each Conv2d module, so I’m trying to figure out how to change the input gradients using a backward hook, but I cannot figure out what to return from the hook function in order to change them.


def backward_hook(module, input_grads, output_grads):
    if isinstance(module, nn.Conv2d):
        input_grads = ..... # changing the input_grads here
        return input_grads, output_grads

for module in model.modules():
    module.register_backward_hook(backward_hook)

And I’m getting this error when calling .backward():

TypeError: expected Variable, but hook returned 'tuple'

I’ve tried all permutations and combinations 🥲 such as:

return input_grads[0], output_grads[0]
return input_grads, output_grads[0]
return input_grads,
return input_grads[0]

etc but still no luck.

What should I return in the case of Conv2d? Also, I could not find any documentation related to this, so it has been very difficult to make progress.
I’ve also noticed that what needs to be returned from the backward hook function depends on the type of the module; for instance, the code below runs just fine with nn.ReLU:


def backward_hook(module, input_grads, output_grads):
    if isinstance(module, nn.ReLU):
        input_grads = ..... # changing the input_grads here
        return input_grads, # so with ReLU it seems we don't need to return output_grads,
                            # but this does not work with Conv2d

for module in model.modules():
    module.register_backward_hook(backward_hook)

I don’t think you should return the output_grads, as they were already passed to the module and won’t be used anymore. Instead you should return all input gradients, which would be passed to the “previous” layer (previous in the sense of the forward execution). Also use register_full_backward_hook, as register_backward_hook might not be working correctly.
From the docs:

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

Here is a minimal example:

import torch
import torch.nn as nn

def backward_hook(m, input_gradients, output_gradients):
    print('input_gradients {}'.format(input_gradients))
    print('output_gradients {}'.format(output_gradients))
    # replace the gradient w.r.t. the conv input with ones
    input_gradients = (torch.ones_like(input_gradients[0]), )
    return input_gradients

conv = nn.Conv2d(1, 1, 3)
conv.register_full_backward_hook(backward_hook)

x = torch.randn(1, 1, 3, 3).requires_grad_()
out = conv(x)
out.mean().backward()
print(x.grad) # ones
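
Applied to the original Conv2d use case, registering the hook only on the Conv2d submodules could look something like this (the toy model and the 0.5 scaling are just placeholders for whatever modification you actually need):

import torch
import torch.nn as nn

def conv_backward_hook(module, input_gradients, output_gradients):
    # return one entry per positional forward input of the Conv2d module;
    # here the incoming input gradient is simply scaled as an illustration
    return (input_gradients[0] * 0.5, )

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, 3, padding=1),
)

# register the full backward hook only on the Conv2d submodules
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        module.register_full_backward_hook(conv_backward_hook)

x = torch.randn(1, 3, 8, 8, requires_grad=True)
model(x).mean().backward()
print(x.grad.shape)  # gradient reaching the input after the hooks modified it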

Thanks a lot @ptrblck, always to the rescue 🙂

Hi,

I have a question regarding the shape of input_gradients and output_gradients.
I’ve checked a few models, but in all cases len(input_gradients) and len(output_gradients) are 1. Is that the case for all input_gradients/output_gradients?
If not, can you show me an example where input_gradients/output_gradients have more than one entry?
Thanks!

No, that’s not always the case; it depends on the number of input and output arguments.
Here is a small example showing the usage of 3 input and 2 output arguments, where one input argument is a constant:

import torch
import torch.nn as nn

def backward_hook(m, input_gradients, output_gradients):
    print('input_gradients {}'.format(input_gradients))
    print('output_gradients {}'.format(output_gradients))
    return input_gradients

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(2, 2)
        
    def forward(self, x, y, static):
        x = self.fc1(x) + static
        y = self.fc2(y)
        return x, y

model = MyModule()
model.register_full_backward_hook(backward_hook)

x = torch.randn(1, 1).requires_grad_()
y = torch.randn(1, 2).requires_grad_()
static = 1.
out_x, out_y = model(x, y, static)
loss = out_x + out_y
loss.mean().backward()
# input_gradients (tensor([[-0.2857]]), tensor([[-0.1686,  0.3697]]), None)
# output_gradients (tensor([[1.]]), tensor([[0.5000, 0.5000]]))

Since static is a plain Python float and not a tensor, its entry in input_gradients is None.
You can play around with this code snippet and add more input/output arguments.
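
For example (a small variation reusing MyModule and backward_hook from the snippet above, with an arbitrary value for static), if static is passed as a tensor that requires gradients instead of a plain float, its entry in input_gradients is no longer None:

model = MyModule()
model.register_full_backward_hook(backward_hook)

x = torch.randn(1, 1, requires_grad=True)
y = torch.randn(1, 2, requires_grad=True)
static = torch.ones(1, 1, requires_grad=True)  # now a tensor instead of a float

out_x, out_y = model(x, y, static)
(out_x + out_y).mean().backward()
# input_gradients now contains three tensors (no None entry),
# since every positional forward input is a tensor that requires gradients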


This makes sense. Thanks a lot!