Masking out gradient before backpropagating

I want to mask out the gradient computed by the loss function before it is backpropagated further. The code below shows what I am trying to do:

 import torch
 model = torch.nn.Linear(2, 4)
 input = torch.rand(2)
 target = torch.rand([2,2])
 # Boolean mask: False marks the gradient entry to zero out
 mask = torch.tensor([[True, False],
                      [True, True]])

 for i in range(3):
     output = model(input)
     output = torch.reshape(output, [2,2])
     # mse_loss takes (input, target), i.e. the prediction first
     loss = torch.nn.functional.mse_loss(output, target)


My mask has shape [2,2]. The backward function of the MSE loss gives a gradient of shape [2,2] with respect to output. I want to multiply this gradient by the mask tensor before it is backpropagated further.

Could someone please give me some insight into how to do this?

I have read some related topics, but none of them solves my problem.


You can do this with a tensor hook. Register it on output, and whatever the hook returns is used in place of the original gradient for the rest of the backward pass:

     output = torch.reshape(output, [2,2])
     # The hook receives the gradient w.r.t. output; its return value replaces it
     output.register_hook(lambda grad: grad * mask.float())
     loss = torch.nn.functional.mse_loss(output, target)
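
In case it helps, here is a minimal self-contained sketch that checks the hook does what is expected. It assumes the default mean reduction of mse_loss, for which the gradient with respect to output is 2 * (output - target) / N. The names x, mask_grad and expected are only for illustration, and the seed is fixed just to make the run repeatable.

```python
import torch

torch.manual_seed(0)  # fixed seed, only so the run is repeatable

model = torch.nn.Linear(2, 4)
x = torch.rand(2)
target = torch.rand(2, 2)
mask = torch.tensor([[True, False],
                     [True, True]])

output = model(x).reshape(2, 2)

def mask_grad(grad):
    # The returned tensor replaces the gradient flowing back through `output`.
    return grad * mask.to(grad.dtype)

output.register_hook(mask_grad)

loss = torch.nn.functional.mse_loss(output, target)
loss.backward()

# Analytic gradient of the mean-reduced MSE w.r.t. `output`, then masked.
expected = 2 * (output.detach() - target) / output.numel() * mask

# The bias gradient of a Linear layer equals the gradient w.r.t. its output,
# so it should match the masked gradient exactly.
print(torch.allclose(model.bias.grad.reshape(2, 2), expected))
# Output unit 1 (position [0, 1] after the reshape) was masked out,
# so its weight row should receive no gradient.
print(model.weight.grad[1])
```

The hook is attached to the reshaped output, so the mask can stay in its [2,2] layout. Since a new output tensor is created on every forward pass, the hook has to be registered again in each iteration of the loop.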

Thanks, it worked for me.