Masking gradient updates to weights using register_hook

I multiply my weight matrix by a mask during the forward pass to sparsify my network. However, during backprop, I want the gradients for the masked weights (where mask[i] = 0) to be zero.

Is there any way to do that using register_hook?

Here’s the toy code I used, but I am getting an error:

import torch

a = torch.ones(5)
mask = torch.tensor([0,1,0,1,0])
a.requires_grad = True
b = 3*a
b.retain_grad()
b.register_hook(mask)  # passing a Tensor here, not a callable
b.mean().backward()
print(a.grad, b.grad)

Error:

--->b.mean().backward()
    TypeError: 'Tensor' object is not callable

I have tried:

b.register_hook(lambda grad:grad*mask)

which works. Is there a better way?
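A self-contained version of the working lambda approach, for reference. It assumes `a = torch.ones(5)` (the snippet above does not show how `a` was originally created) and uses a float mask. `retain_grad()` is left out so that the sketch only checks `a.grad`. The `handle` returned by `register_hook` is standard PyTorch and can be used to remove the hook later.

```python
import torch

# Toy setup from the question: a is a leaf tensor, b = 3 * a.
a = torch.ones(5, requires_grad=True)
# Float mask so that grad * mask keeps the gradient's dtype explicit.
mask = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0])

b = 3 * a
# register_hook takes a callable: it receives the gradient w.r.t. b and
# may return a replacement gradient that is propagated further back.
handle = b.register_hook(lambda grad: grad * mask)

b.mean().backward()
# d(mean)/db is 1/5 per element; the hook masks it, then autograd multiplies by 3 for a.
print(a.grad)  # masked positions are zero

# The handle removes the hook once it is no longer needed.
handle.remove()
```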


Hi,

A hook has to be a callable function, not a Tensor. So the lambda is the way to go.
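Any callable works, not only a lambda. Below is a minimal sketch that applies the same idea to a layer's weight. It rests on illustrative assumptions: `nn.Linear(4, 3)`, the `weight_mask` pattern, and the helper name `mask_grad` are hypothetical, and the optimizer is plain SGD without momentum or weight decay. With that optimizer, a zero gradient means the masked weights are not updated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(4, 3)
# Hypothetical fixed sparsity pattern with the same shape as layer.weight (3 x 4).
weight_mask = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                            [0.0, 1.0, 0.0, 1.0],
                            [1.0, 1.0, 0.0, 0.0]])

def mask_grad(grad):
    # Hypothetical helper: any callable works as a hook, and the tensor it
    # returns replaces the gradient before it is accumulated into .grad.
    return grad * weight_mask

layer.weight.register_hook(mask_grad)

# Plain SGD (no momentum, no weight decay), so a zero gradient means no update.
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
before = layer.weight.detach().clone()

x = torch.randn(8, 4)
loss = layer(x).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

# True only where weight_mask == 1 (barring an exactly-zero gradient there).
changed = layer.weight.detach() != before
print(layer.weight.grad)
print(changed)
```

There are two caveats. First, weight decay is added inside `optimizer.step()`, after the hook has run, so it can still change masked weights. Second, if the forward pass already computes `weight * mask`, the chain rule multiplies the gradient reaching `weight` by `mask` anyway, so in that setup the masked entries of `weight.grad` are zero even without the hook.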


Thanks a lot for your response. I have another question. Let's say I introduce another intermediate tensor c like this:

import torch

a = torch.ones(5)
a.requires_grad = True
mask = torch.tensor([1,0,0,1,0])
b = 3*a
c = 2*b

b.retain_grad()
c.retain_grad()

b.register_hook(lambda grad: grad*mask)

c.mean().backward()
print(a.grad, b.grad, c.grad)

Here, only a.grad gets masked. Both a.grad and b.grad get masked only when I use

c.register_hook(lambda grad: grad*mask)  

Why is that? Shouldn't b.grad be masked in the first scenario as well?
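Here is a self-contained version of the second variant, with the hook on `c` and `import torch` added. The mask is applied to the gradient arriving at `c`, so the gradients computed from it for `b` and `a` are masked too. The sketch does not check `c.grad`. What `retain_grad()` stores for `c` depends on when its hook runs relative to the mask hook, and that ordering is exactly what is being asked about.

```python
import torch

a = torch.ones(5, requires_grad=True)
mask = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0])
b = 3 * a
c = 2 * b

b.retain_grad()
c.retain_grad()

# Masking at c: every gradient computed from c's gradient (b's, then a's) inherits the mask.
c.register_hook(lambda grad: grad * mask)

c.mean().backward()
print(a.grad, b.grad)  # both zero wherever mask == 0
```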

The thing is that retain_grad() is also a hook, and hooks are executed in the order they are added. So if you register your masking hook after calling retain_grad(), the gradient that retain_grad() stores in b.grad won’t be masked.
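Below is a minimal sketch of that fix: register the masking hook on `b` first, then call `retain_grad()`, so the value stored in `b.grad` is the masked gradient. One hedge applies. The autograd notes for more recent PyTorch releases describe `.grad` being filled for `retain_grad()` tensors after the tensor hooks have run, in which case the registration order may no longer matter here. Registering the mask first gives a masked `b.grad` either way.

```python
import torch

a = torch.ones(5, requires_grad=True)
mask = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0])
b = 3 * a
c = 2 * b

# Register the masking hook first, then retain_grad(), so that b.grad
# records the gradient after the mask has been applied.
b.register_hook(lambda grad: grad * mask)
b.retain_grad()

c.mean().backward()
print(a.grad)  # masked
print(b.grad)  # masked as well
```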

Thanks a lot! It's clear now!