I want to mask the gradient computed by the loss function before it is backpropagated further through the model. The code below illustrates what I am trying to do:
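```python
import torch

model = torch.nn.Linear(2, 4)
input = torch.rand(2)
target = torch.rand([2, 2])
# Boolean mask with the same shape as the reshaped output
mask = torch.tensor([[True, False],
                     [True, True]])

model.train()
for i in range(3):
    model.zero_grad()                       # clear gradients from the previous iteration
    output = model(input)                   # shape [4]
    output = torch.reshape(output, [2, 2])  # shape [2, 2], same as target and mask
    loss = torch.nn.functional.mse_loss(output, target)  # argument order is (input, target)
    loss.backward()
```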
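A minimal sketch for inspecting the gradient that the loss sends back to `output`, assuming the same shapes as above (the fixed seed is added only for reproducibility). `Tensor.retain_grad()` keeps `.grad` on the non-leaf `output` tensor after `backward()`, and for the default mean reduction over N = 4 elements this gradient equals 2 * (output - target) / N:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 4)
input = torch.rand(2)
target = torch.rand([2, 2])

output = torch.reshape(model(input), [2, 2])
output.retain_grad()  # keep .grad on this non-leaf tensor after backward()
loss = torch.nn.functional.mse_loss(output, target)
loss.backward()

# d(loss)/d(output) for mean-reduced MSE over N elements is 2 * (output - target) / N
expected = 2 * (output.detach() - target) / output.numel()
print(output.grad.shape)                      # torch.Size([2, 2])
print(torch.allclose(output.grad, expected))  # True
```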
My mask has shape [2, 2]. The backward pass of the MSE loss produces a gradient with respect to `output` that also has shape [2, 2]. I want to apply the mask to this gradient (zeroing the entries where the mask is False) before it is backpropagated further into the model.
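One way to do this, sketched under the assumption that the mask should zero the gradient wherever it is False, is a tensor hook. `Tensor.register_hook` calls the given function with the gradient of the loss with respect to `output` during `backward()`, and the tensor the function returns replaces that gradient for the rest of the backward pass. Because the hook is registered on the freshly created `output` in each iteration, hooks do not accumulate across iterations:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 4)
input = torch.rand(2)
target = torch.rand([2, 2])
# Boolean mask: True = let the gradient through, False = zero it
mask = torch.tensor([[True, False],
                     [True, True]])

model.train()
for i in range(3):
    model.zero_grad()
    output = torch.reshape(model(input), [2, 2])
    # The hook receives d(loss)/d(output) with shape [2, 2]; its return value
    # is what continues to flow back into the Linear layer.
    output.register_hook(lambda grad: grad.masked_fill(~mask, 0.0))
    loss = torch.nn.functional.mse_loss(output, target)
    loss.backward()
```

The forward pass and the loss value are unchanged; only the gradient reaching `model.weight` and `model.bias` is masked. Here the entry at position [0, 1] is masked, so row 1 of `model.weight.grad` and `model.bias.grad[1]` end up as zero.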
Could someone please explain how to do this?
I have read several related topics, but none of them solves my problem.
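An alternative that avoids hooks, sketched here as an assumption rather than the only correct approach, is to detach the masked-out entries in the forward pass with `torch.where(mask, output, output.detach())`. The values (and therefore the loss) stay the same, but no gradient flows back through the detached entries. The helper functions `grads_with_hook` and `grads_with_detach` are hypothetical names used only to compare the two approaches:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 4)
input = torch.rand(2)
target = torch.rand([2, 2])
mask = torch.tensor([[True, False],
                     [True, True]])

def grads_with_hook():
    model.zero_grad()
    output = torch.reshape(model(input), [2, 2])
    output.register_hook(lambda grad: grad.masked_fill(~mask, 0.0))
    torch.nn.functional.mse_loss(output, target).backward()
    return model.weight.grad.clone(), model.bias.grad.clone()

def grads_with_detach():
    model.zero_grad()
    output = torch.reshape(model(input), [2, 2])
    # Same forward values, but masked-out entries are detached from the graph,
    # so their gradient never reaches the model.
    output = torch.where(mask, output, output.detach())
    torch.nn.functional.mse_loss(output, target).backward()
    return model.weight.grad.clone(), model.bias.grad.clone()

hook_w, hook_b = grads_with_hook()
det_w, det_b = grads_with_detach()
print(torch.allclose(hook_w, det_w), torch.allclose(hook_b, det_b))  # True True
```

Note that computing the loss on `output[mask]` and `target[mask]` instead is not quite equivalent: the mean would then divide by the number of kept elements (3 here) rather than 4, which rescales the remaining gradients.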