Problems with register_backward_hook

Hi. We have the following sample code:

```python
import torch

class xxx:
    def __init__(self, model):
        self.model = model
        self.gradients = None
        self.model.eval()
        self.hook_layers()

    def hook_layers(self):
        def hook_function(module, grad_in, grad_out):
            self.gradients = grad_in[0]
        first_layer = list(self.model.children())[0]  # Get the 1st layer.
        first_layer.register_backward_hook(hook_function)

    def generate_gradients(self, input_image, target_class):
        # Forward pass
        model_output = self.model(input_image)
        model_output = model_output.reshape(-1)
        input_image.requires_grad = True
        self.model.zero_grad()
        one_hot_output = torch.FloatTensor(model_output.size()[-1]).zero_()
        one_hot_output[target_class] = 1
        # Backward pass
        model_output.backward(gradient=one_hot_output)

        gradients_as_arr = self.gradients.data.numpy()[0]
        return gradients_as_arr
```

input_image is just a PIL image preprocessed by torchvision.

When I run the code, the error says self.gradients.data does not exist. In fact, grad_in looks like (None, , None). What do these entries in grad_in mean? And how do I correctly save the gradient of the first layer?

Many thanks~

Hi,

input_image in your example should be a Tensor. Otherwise, setting requires_grad on it will have no effect.
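A minimal sketch of turning an image into a Tensor that can track gradients. It assumes the image has already been decoded to an H x W x 3 uint8 NumPy array (random data stands in for it here; with a real PIL image you would use np.asarray or torchvision's ToTensor transform, which does the same HWC-to-CHW, [0, 1] scaling):

```python
import numpy as np
import torch

# Hypothetical stand-in for a decoded PIL image: an H x W x 3 uint8 array.
pil_like = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)

# Convert to a float tensor of shape (1, 3, H, W) scaled to [0, 1].
input_image = torch.from_numpy(pil_like).permute(2, 0, 1).float().div(255).unsqueeze(0)

# Mark it as requiring grad *before* it is fed to the model,
# so the forward pass records the operations applied to it.
input_image.requires_grad_(True)

print(input_image.shape, input_image.is_leaf, input_image.requires_grad)
```

Because none of the conversion steps involve a tensor that requires grad, the result is still a leaf, so its .grad field will be populated by the backward pass.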

Also, you should not use .data anymore; use .detach() instead.
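A small illustration of where .detach() is needed: calling .numpy() on a tensor that requires grad raises an error, while its detached counterpart converts fine. The tensors here are toy examples, not taken from the original code:

```python
import torch

x = torch.randn(4, requires_grad=True)
y = x * 2  # part of the autograd graph, so y.requires_grad is True

# y.numpy() would raise a RuntimeError because y requires grad.
# .detach() returns a tensor sharing y's storage but cut off from the graph,
# which is the recommended replacement for the legacy .data attribute.
y_np = y.detach().numpy()

# The same pattern works for a gradient before converting it to NumPy.
y.sum().backward()
grad_np = x.grad.detach().numpy()
print(grad_np)  # [2. 2. 2. 2.]
```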

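Are you trying to get the gradient with respect to the input of your network? If so, set requires_grad=True on input_image before the forward pass (in your code it is set after the forward, so autograd never records operations on it). Since input_image is a leaf Tensor, after the call to .backward() you can then simply access its gradient via input_image.grad.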
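A minimal sketch of that approach without any module hooks. The small nn.Sequential model, the 32 x 32 input, and target_class = 3 are hypothetical stand-ins for your real model and preprocessed image:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small classifier standing in for the real model.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()

input_image = torch.rand(1, 3, 32, 32)
input_image.requires_grad_(True)  # must happen before the forward pass

target_class = 3
model_output = model(input_image)          # shape (1, 10)
one_hot = torch.zeros_like(model_output)   # same shape/dtype as the output
one_hot[0, target_class] = 1.0

model.zero_grad()
model_output.backward(gradient=one_hot)

# input_image is a leaf, so autograd fills its .grad directly.
saliency = input_image.grad.detach()[0]    # shape (3, 32, 32)
print(saliency.shape)
```

Using torch.zeros_like for the one-hot vector keeps its shape and dtype matched to the output, which avoids the reshape done in the question's code.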

Finally, as the documentation for register_backward_hook notes, it does not work correctly with all nn.Modules. So if you want the gradient of a specific Tensor, you might want to use register_hook directly on that Tensor.
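A sketch of the Tensor-hook approach, assuming a hypothetical model whose first child is "the first layer". A forward hook on that layer registers a Tensor hook on its output as the output is produced, and a second Tensor hook is placed on the input itself:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model; the first child plays the role of "the first layer".
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 8 * 8, 10),
)
model.eval()

saved = {}

def save_grad(name):
    # A Tensor hook receives the gradient of the tensor it is registered on.
    def hook(grad):
        saved[name] = grad.detach()
    return hook

def attach_output_hook(module, inputs, output):
    # Forward hook: register a Tensor hook on this layer's output.
    output.register_hook(save_grad("first_layer_output"))

first_layer = list(model.children())[0]
handle = first_layer.register_forward_hook(attach_output_hook)

input_image = torch.rand(1, 3, 8, 8, requires_grad=True)
input_image.register_hook(save_grad("input"))

model(input_image)[0, 3].backward()
handle.remove()  # stop adding Tensor hooks on later forward passes

print(saved["input"].shape, saved["first_layer_output"].shape)
```

Newer PyTorch releases also offer Module.register_full_backward_hook, which is meant to address the limitations of register_backward_hook; check the documentation for your version before relying on it.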
Alternatively, if you have access to that Tensor during the forward pass, simply call .retain_grad() on it so that its .grad field is populated during the backward pass.
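A minimal sketch of the .retain_grad() route, with a hypothetical conv layer and head standing in for the real model split at its first layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # hypothetical "first layer"
head = nn.Sequential(nn.ReLU(), nn.Flatten(), nn.Linear(8 * 8 * 8, 10))

input_image = torch.rand(1, 3, 8, 8)
first_out = conv(input_image)  # non-leaf tensor: .grad is normally not kept
first_out.retain_grad()        # ask autograd to populate first_out.grad
output = head(first_out)
output[0, 3].backward()

print(first_out.grad.shape)
```

This only works when you control the forward pass and can hold a reference to the intermediate tensor; otherwise the Tensor hook above is the more flexible option.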