Holding certain weights fixed during training

Hello,
I have a tensor computed during the forward pass, and I want some of its values to contribute nothing to training: neither in the forward pass nor in the backward pass.
If I understand correctly, the way to do this is to add these lines to forward():

t = input_a @ mat_b  # create the tensor
with torch.no_grad():
    t[i, j] = 0  # zero the entry so it passes no information in the forward pass
mask = torch.ones_like(t)
mask[i, j] = 0
t.register_hook(lambda grad: grad * mask)  # zero its gradient so it does not affect backprop
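To make the approach checkable, here is a self-contained sketch that wraps the snippet in a toy module. The class name `MaskedMatmul`, passing (i, j) to the constructor, and the weighted-sum loss are assumptions for illustration only. The hook returns a new tensor (`grad * mask`) instead of calling `grad.mul_(mask)`, because the PyTorch docs say a tensor hook should not modify its argument and should return the replacement gradient instead.

```python
import torch
import torch.nn as nn


class MaskedMatmul(nn.Module):
    """Hypothetical toy module: computes input_a @ mat_b, then removes
    entry (i, j) from both the forward and the backward pass."""

    def __init__(self, i: int, j: int):
        super().__init__()
        self.i, self.j = i, j

    def forward(self, input_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
        t = input_a @ mat_b
        with torch.no_grad():
            # Forward: overwrite the entry without recording it in the graph.
            t[self.i, self.j] = 0
        mask = torch.ones_like(t)  # does not require grad
        mask[self.i, self.j] = 0
        # Backward: return a new, masked gradient rather than editing grad in place.
        t.register_hook(lambda grad: grad * mask)
        return t


# Usage: a toy loss whose upstream gradient w.r.t. t is exactly w.
torch.manual_seed(0)
a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)
t = MaskedMatmul(1, 2)(a, b)
w = torch.randn(3, 5)
(t * w).sum().backward()
```

With the hook in place, the gradients reaching `a` and `b` are the ones you would get if `w[1, 2]` were zero, i.e. entry (1, 2) of `t` sends nothing back through the matmul.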

Am I correct? Is there a better way to do so?
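One commonly used alternative, offered here as a sketch rather than a definitive answer, is to multiply the tensor by a constant 0/1 mask instead of combining a `no_grad` write with a hook. The helper name `masked_matmul` and the toy loss are hypothetical.

```python
import torch


def masked_matmul(input_a: torch.Tensor, mat_b: torch.Tensor, i: int, j: int) -> torch.Tensor:
    """Hypothetical helper: input_a @ mat_b with entry (i, j) cut out of
    both the forward and the backward pass via a constant mask."""
    t = input_a @ mat_b
    mask = torch.ones_like(t)  # does not require grad, so it is treated as a constant
    mask[i, j] = 0
    # Forward: t[i, j] becomes 0.
    # Backward: autograd multiplies the incoming gradient by the same mask,
    # so no gradient flows back through entry (i, j).
    return t * mask


# Usage with the same toy loss as before.
torch.manual_seed(0)
a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)
t = masked_matmul(a, b, 1, 2)
w = torch.randn(3, 5)
(t * w).sum().backward()
```

The appeal of this design is that the forward zeroing and the gradient zeroing come from one ordinary differentiable op, so they cannot drift apart, and there is no in-place write or hook to manage. `t.masked_fill(bool_mask, 0)` with a boolean mask should behave the same way.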

Thanks!
