Backward hook uncaught in grad of grad

Hi pytorch friends!

I'm trying to implement a fast gradient penalty using forward and backward hooks, but I found that for gradients of gradients the hooks show slightly aberrant behavior. Hopefully you can help me find where I go wrong.

I tried to construct a minimal example that shows the behavior. Let me start by registering hooks on a simple linear model with one parameter.

import numpy as np
import torch
from torch import nn
from torch import autograd

torch.manual_seed(42)

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1, bias=False)
        self.lin.weight = nn.Parameter(torch.tensor([[2.0]]))

    def forward(self, x):
        # Switch comments in the following lines for an aberrant behavior.
        return self.lin(x)
        # return torch.log(torch.exp(self.lin(x)))

model = Model()

# Add hooks to construct per-example gradients
gradients = {}
activations = {}
def forward_hook(layer, inputs, outputs):
    activations[layer] = inputs[0].detach()

def backward_hook(layer, grad_input, grad_output):
    A = activations[layer]
    B = grad_output[0].detach()
    gradients[layer] = torch.einsum("ni,nj->nij", B, A)

model.lin.register_forward_hook(forward_hook)
model.lin.register_backward_hook(backward_hook)

The backward hook computes per-example gradients specifically for a layer of type nn.Linear. Let's feed the model a batch of 2 input examples and compute the parameter gradients. We also check that the per-example gradients are computed correctly, but we'll comment that check out later.
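
(The einsum is just the usual outer-product form of the linear-layer gradient: for y_n = W a_n, the per-example gradient of the loss with respect to W is g_n a_n^T, where g_n is row n of grad_output and a_n is row n of the input, so summing these outer products over the batch should reproduce weight.grad.)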

x = torch.FloatTensor([2, 1])[..., None]
x.requires_grad_(True)
loss = model(x).sum()
# We'll comment this out later
loss.backward()
for layer, g in gradients.items():
    assert torch.allclose(layer.weight.grad, g.sum(dim=0))
# .. until here.
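
To make the numbers concrete: with w = 2 and x = [2, 1], loss = 2*2 + 2*1 = 6, so dloss/dw = x1 + x2 = 3, and the per-example contributions are 2 and 1.

print(model.lin.weight.grad)           # tensor([[3.]])
print(gradients[model.lin].squeeze())  # tensor([2., 1.])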

So the hooks seem to work, and I was pretty happy with myself for having successfully stolen the trick from @Yaroslav_Bulatov and his colleagues.

However, I actually need to compute the gradient of the loss with respect to the input data x, and construct a cost function from that.

grads = autograd.grad(
    outputs=loss, inputs=x,
    retain_graph=True, create_graph=True, only_inputs=True
)[0]
grad_loss = (grads ** 2).sum()
model.zero_grad()
grad_loss.backward()

Unfortunately the test fails now:

grad_loss = (grads ** 2).sum()
model.zero_grad()
grad_loss.backward()
for layer, g in gradients.items():
    assert layer.weight.grad == 8, layer.weight.grad
    assert g.sum(dim=0) == 3, g
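
As a sanity check, here is the same computation done one example at a time, without any hooks (it reuses the Model class from above). Each per-example gradient of grad_loss with respect to the weight should come out as 4, so the per-example gradients should sum to the 8 seen in weight.grad, not to 3:

for x_val in (2.0, 1.0):
    ref_model = Model()  # fresh instance, so no hooks are attached to it
    x_i = torch.tensor([[x_val]], requires_grad=True)
    loss_i = ref_model(x_i).sum()
    g_i = autograd.grad(loss_i, x_i, create_graph=True)[0]
    w_grad_i = autograd.grad((g_i ** 2).sum(), ref_model.lin.weight)[0]
    print(w_grad_i)  # expect tensor([[4.]]) for both examples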

The situation gets even more confusing when you switch the comments in Model.forward. Note that we take log(exp(x)), which shouldn't change anything, but subsequently the hooked gradients are all zero.

Here’s a notebook.

Please help :confused:

Hi,

As mentioned in the doc for .register_backward_hook(), these hooks are fairly broken right now and will give aberrant behavior :'( We are working on fixing it, but it is proving quite challenging.

In particular, the grad_input value will be wrong as soon as your nn.Module does more than one operation: the hook ends up attached to the grad_fn of the module's output, so what it reports corresponds to the last operation in the forward, not to the Module as a whole.

If you have a single Tensor as input and output, you can use Tensor.register_hook(), called on the tensors themselves, in combination with nn.Module.register_forward_hook() to get the behavior you want.

It would look like this (I did not run the code, so there may be typos):

# Add hooks to construct per-example gradients
gradients = {}
activations = {}
def forward_hook(layer, inputs, outputs):
    activations[layer] = inputs[0].detach()

    output_grads = [None]  # filled in by hook_output before hook_input fires
    def hook_output(grad_output):
        output_grads[0] = grad_output.detach()

    def hook_input(grad_input):
        A = activations[layer]
        B = output_grads[0]
        gradients[layer] = torch.einsum("ni,nj->nij", B, A)

    inputs[0].register_hook(hook_input)
    outputs[0].register_hook(hook_output)

model.lin.register_forward_hook(forward_hook)

I explored your solution, but without much luck so far. Indeed, there's also a typo: outputs[0].register... -> outputs.register.... Thanks though; let me know how I can help with a module fix.
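
For what it's worth, a plain first-order check with that fix applied looks like it should pass (untested sketch; it reuses Model, forward_hook and the gradients dict from above, and x has to require grad so that the hook on inputs[0] fires at all):

gradients.clear()
model = Model()
model.lin.register_forward_hook(forward_hook)

x = torch.FloatTensor([2, 1])[..., None].requires_grad_(True)
model(x).sum().backward()
for layer, g in gradients.items():
    assert torch.allclose(layer.weight.grad, g.sum(dim=0))  # both tensor([[3.]])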

Shouldn't the activations of the layer be the outputs variable?
P.S. Since we are using the word activation, shouldn't it require applying a ReLU or another activation function? (Sorry, ignore this.)
