I have a certain tensor calculated during the forward pass and I want some of its values to contribute nothing to the training of the model - neither in the forward nor in the backward pass.
If I understand correctly, the way to do it is by adding these lines:
```python
t = input_a @ mat_b  # Create tensor

with torch.no_grad():
    t[i, j] = 0  # eliminate passing any information during FP

mask = torch.ones_like(t)
mask[i, j] = 0
t.register_hook(lambda grad: grad.mul_(mask))  # Set gradient to 0 so it doesn't affect BP
```
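For context, the snippet can be run end-to-end roughly as follows. The shapes, the index `(i, j)`, and the toy loss are illustrative assumptions, and the hook here returns a fresh `grad * mask` instead of mutating `grad` in place (in-place modification of the gradient inside a hook is discouraged):

```python
import torch

# Illustrative shapes and index -- not from the original question
input_a = torch.randn(4, 5, requires_grad=True)
mat_b = torch.randn(5, 3)
i, j = 1, 2

t = input_a @ mat_b            # create the tensor
with torch.no_grad():
    t[i, j] = 0                # zero the value in the forward pass

mask = torch.ones_like(t)
mask[i, j] = 0
# Zero the corresponding gradient entry in the backward pass
t.register_hook(lambda grad: grad * mask)

loss = t.sum()                 # toy loss, just to drive backward()
loss.backward()

# t[i, j] contributed nothing to the forward pass,
# and input_a.grad equals mask @ mat_b.T, i.e. the (i, j)
# entry contributed nothing to the backward pass either.
```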
Am I correct? Is there a better way to do so?