Status of register_backward_hook


(James Proud) #1

It appears that backward hooks are currently broken as the recent issue and PR below discuss.


Is there a path towards fixing this, or an appropriate way to implement a backward hook in lieu of using register_backward_hook?


(Alban D) #2

Hi,

Unfortunately, it is a really tricky problem to solve without impacting users who don't use the hook, while still supporting most of the nn.Module behaviours.
It is very hard to make it work nicely with inplace operations, and it is not even possible at the moment to forbid inplace ops in the cases we don't support.
So there is some progress on cleaning up these inplace ops, and the backward hooks will come after.

Unfortunately, I have a limited amount of time to do all the changes that are needed, but if someone wants to work on this, I can help.

At the moment, the workaround is to use register_hook directly on the specific Tensors that you want to change.
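For example, here is a minimal sketch of that workaround, assuming a toy leaf tensor w and a hook that simply halves its gradient (all names are illustrative):

import torch

w = torch.rand(3, 3, requires_grad=True)

def halve_grad(grad):
    # The hook receives the gradient w.r.t. w and may return a replacement;
    # it must not modify `grad` in place.
    return grad * 0.5

w.register_hook(halve_grad)

loss = (2 * w).sum()
loss.backward()
print(w.grad)  # every entry is 1.0 instead of 2.0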


(James Proud) #3

For my use case, using register_hook would work as there is only one layer of gradients that I want to target and modify on the backward pass. What’s the best way of implementing this?


(Alban D) #4

Hi,

Could you give more details (or a small code sample) of what you want to do exactly? That way I can give you an exact code sample.


(James Proud) #5

I’ve included some sample code below. I am trying to mask the gradients of layer1 in this example. As you can see, I’m zeroing out the weights upon initialisation.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class MaskedLinear(nn.Module):
    def __init__(self, in_features, out_features, indices_mask):
        super(MaskedLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features).cuda()
        self.indices_mask = indices_mask
        good_weights = self.linear.weight.data[indices_mask]
        self.linear.weight.data[:] = 0                        # zero all weights out
        self.linear.weight.data[indices_mask] = good_weights  # restore the unmasked weights
        self.indices_mask_tensor = torch.ByteTensor(out_features, in_features).cuda()
        self.indices_mask_tensor[indices_mask] = 1

    # Currently unused. Intended as the backward hook: scatter the incoming
    # gradient through the mask so masked positions receive no gradient.
    def _zero_grad_mask(self, module, grad_input, grad_output):
        new_grad = Variable(
            torch.Tensor(grad_input[0].shape).cuda().masked_scatter_(self.indices_mask_tensor, grad_input[0].data)
        )
        return (grad_input[0], None, new_grad)

    def forward(self, input):
        # Re-applying the mask on every forward pass works as a hack,
        # but it is incredibly slow.
        good_weights = self.linear.weight.data[self.indices_mask]
        self.linear.weight.data[:] = 0
        self.linear.weight.data[self.indices_mask] = good_weights
        return self.linear(input)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer1 = MaskedLinear(in_dim, out_dim, mask)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        return x

(Alban D) #6

Hi,

In that case, you can register the hook directly on self.linear.weight. See the sample below:

import torch
from torch import nn

l = nn.Linear(5, 5)

inp = torch.rand(3, 5)

print("original weights")
print(l.weight)

opt = torch.optim.SGD(l.parameters(), lr=0.01)

mask = torch.rand(5, 5).gt(0.5)
print("our mask")
print(mask)
def hook_fn(grad):
    # We are not allowed to modify the given grad in-place, so clone it first.
    out = grad.clone()
    out[mask] = 0
    return out

# Registering the hook only once. This can be done in
# MaskedLinear's __init__ function
l.weight.register_hook(hook_fn)

print("Loss goes down")
for _ in range(50):
    loss = l(inp).abs().sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss)
print("only some weights have changed")
print(l.weight)
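To tie this back to the original MaskedLinear, here is a rough sketch of registering the hook once in __init__ (a simplified CPU-only version; weight_mask is assumed to be an (out_features, in_features) 0/1 tensor marking the weights to keep):

import torch
from torch import nn

class MaskedLinear(nn.Module):
    def __init__(self, in_features, out_features, weight_mask):
        super(MaskedLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        # 1 where the weight is kept, 0 where it is pruned.
        self.register_buffer("weight_mask", weight_mask.float())
        # Apply the mask once to the initial weights.
        self.linear.weight.data.mul_(self.weight_mask)
        # Register the hook once; it runs on every backward pass and
        # zeroes the gradient of the pruned weights.
        self.linear.weight.register_hook(lambda grad: grad * self.weight_mask)

    def forward(self, input):
        return self.linear(input)

With plain SGD (no momentum or weight decay), the pruned weights start at zero and never receive gradient, so the per-forward re-masking hack in the earlier code should no longer be necessary.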