Status of register_backward_hook

It appears that backward hooks are currently broken, as the recent issue and PR below discuss.


Is there a path towards fixing this, or an appropriate way to implement a backward-pass modification in lieu of using register_backward_hook?

Hi,

Unfortunately it is a really tricky problem to solve without impacting users who don't use the hook, while still supporting most of the nn.Module behaviours.
It is very hard to make it work nicely with inplace operations, and at the moment it is not even possible to forbid inplace ops in the cases we don't support.
So there is some progress on cleaning up these inplace ops, and the backward hooks will come after.

Unfortunately, I have a limited amount of time to make all the changes that are needed, but if someone wants to work on this I can help.

At the moment, the workaround is to use register_hook directly on the specific Tensors whose gradients you want to change.
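
As a quick illustration (the tensor and the scaling below are just placeholders, not from anyone's actual code), a Tensor hook looks like this:

import torch

w = torch.randn(3, 3, requires_grad=True)  # any Tensor whose gradient you want to change

def scale_grad(grad):
    # Return a new Tensor instead of modifying `grad` inplace
    return grad * 0.5

w.register_hook(scale_grad)

(w ** 2).sum().backward()
print(w.grad)  # equals w here, i.e. 2 * w scaled by 0.5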

For my use case, using register_hook would work, as there is only one layer whose gradients I want to target and modify on the backward pass. What’s the best way of implementing this?

Hi,

Could you give more details (or a small code sample) of what you want to do exactly? That way I can give you an exact code sample.

I’ve included some sample code below. I am trying to mask the gradients of layer1 in this example. As you can see, I’m zeroing out the weights upon initialisation.

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable

class MaskedLinear(nn.Module):
    # Currently unused. Intended for a backward hook.
    def _zero_grad_mask(self, module, grad_input, grad_output):
        new_grad = Variable(
            torch.zeros(grad_input[0].shape).cuda().masked_scatter_(self.indices_mask_tensor, grad_input[0].data)
        )
        return (grad_input[0], None, new_grad)

    def __init__(self, in_features, out_features, indices_mask):
        super(MaskedLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features).cuda()
        self.indices_mask = indices_mask  # kept around for the forward-pass hack below
        good_weights = self.linear.weight.data[indices_mask]
        self.linear.weight.data[:] = 0  # zero all weights out
        self.linear.weight.data[indices_mask] = good_weights  # repopulate the allowed weights
        self.indices_mask_tensor = torch.ByteTensor(out_features, in_features).zero_().cuda()
        self.indices_mask_tensor[indices_mask] = 1

    def forward(self, input):
        # Repeating the original masking here works as a hack, but is incredibly slow.
        good_weights = self.linear.weight.data[self.indices_mask]
        self.linear.weight.data[:] = 0
        self.linear.weight.data[self.indices_mask] = good_weights
        return self.linear(input)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # in_dim, out_dim and mask are defined elsewhere
        self.layer1 = MaskedLinear(in_dim, out_dim, mask)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        return x

Hi,

In that case, you can register the hook directly on self.linear.weight. See the sample below:

import torch
from torch import nn

l = nn.Linear(5, 5)

inp = torch.rand(3, 5)

print("original weights")
print(l.weight)

opt = torch.optim.SGD(l.parameters(), lr=0.01)

mask = torch.rand(5, 5).gt(0.5)
print("our mask")
print(mask)
def hook_fn(grad):
    # You are not allowed to modify inplace what is given !
    out = grad.clone()
    out[mask] = 0
    return out

# Registering the hook only once. This can be done in
# MaskedLinear's __init__ function
l.weight.register_hook(hook_fn)

print("Loss goes down")
for _ in range(50):
    loss = l(inp).abs().sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss)
print("only some weights have changed")
print(l.weight)
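
As a rough sketch only (adapt the details to your exact setup), the hook could be registered once in MaskedLinear's __init__ on self.linear.weight, reusing your indices_mask_tensor idea:

def __init__(self, in_features, out_features, indices_mask):
    super(MaskedLinear, self).__init__()
    self.linear = nn.Linear(in_features, out_features)
    # 1 where the weight is allowed to change, 0 elsewhere
    self.indices_mask_tensor = torch.zeros(out_features, in_features).byte()
    self.indices_mask_tensor[indices_mask] = 1

    def mask_grad(grad):
        # Clone first: the grad given to the hook must not be modified inplace
        out = grad.clone()
        out[self.indices_mask_tensor == 0] = 0  # kill gradients of the masked-out weights
        return out

    # Registered once here; it runs on every backward pass
    self.linear.weight.register_hook(mask_grad)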

Hi,

I have a complex module where I want gradients to back-propagate only over a specific region of the input. Can I do this in the forward function, on the input itself?

def forward(self, input):
    self.loss = nn.MSELoss()(input, self.target)

    def hook_fn(grad):
        return grad * self.mask  # self.mask is a binary mask covering a subset of input
    input.register_hook(hook_fn)

    return input

Hi,

Yes, registering a hook on the Tensor of interest during the forward pass is the way to go.
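
For completeness, a tiny self-contained sketch of that pattern (the module, mask and computation here are made up for illustration; note that register_hook only works on a Tensor that requires grad):

import torch
from torch import nn

class RegionGrad(nn.Module):
    def __init__(self, mask):
        super(RegionGrad, self).__init__()
        self.mask = mask  # 1 where gradients may flow, 0 elsewhere

    def forward(self, input):
        def hook_fn(grad):
            return grad * self.mask  # returns a new Tensor, no inplace change
        if input.requires_grad:
            input.register_hook(hook_fn)
        return input * 2  # stand-in for the real computation

mask = torch.tensor([[1., 0.], [0., 1.]])
x = torch.ones(2, 2, requires_grad=True)
RegionGrad(mask)(x).sum().backward()
print(x.grad)  # 2 where mask is 1, 0 elsewhere

One thing to keep in mind: every call to forward registers another hook on that particular Tensor, which is fine when the input is a fresh activation each iteration, as in the snippet above.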


Hi, thanks for your reply. I have a question about the hook_fn.
Why do you use grad.clone() inside the hook function? Can’t we just use

def hook_fn(grad):
    grad[mask] = 0
    return grad

I noticed your comment “you are not allowed to modify inplace what is given”.
Does this mean we need to clone the gradient inside register_hook functions for tensors?
After all, the two functions have the same return values.
Thanks a lot.

Hi,

Yes, you have to clone it if you plan on making inplace changes!
The grad Tensor might be used elsewhere in the autograd graph, and changing its value inplace can lead to silently wrong gradients!
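
To make the difference concrete, here is a small contrast sketch (illustrative only, reusing the mask idea from the earlier example):

import torch

mask = torch.rand(5, 5).gt(0.5)

# Safe: work on a copy of the incoming gradient
def safe_hook(grad):
    out = grad.clone()
    out[mask] = 0
    return out

# Risky: mutates the Tensor the autograd engine handed to the hook.
# That same Tensor may still be needed to compute other gradients,
# so this can silently corrupt them.
def risky_hook(grad):
    grad[mask] = 0
    return grad

w = torch.randn(5, 5, requires_grad=True)
w.register_hook(safe_hook)
(w * 2).sum().backward()
print(w.grad)  # 0 where mask is True, 2 elsewhere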
