Hi Soroush!
Just to be clear about why the masks are not being applied to the model:
As mentioned above, this code:
for idx, param in enumerate(model.parameters()):
    param = param * masks[idx]
creates new tensors but does not modify the model.parameters() tensors.
In this context, param = param * masks[idx] is effectively equivalent to
new_tensor_that_will_be_ignored_and_discarded = param * masks[idx]; the
loop variable param is simply rebound to a new tensor, while the actual
parameter is left untouched.
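This is ordinary python name binding, not anything specific to pytorch. As a
toy analogy (not from your code), rebinding a loop variable never changes the
container it came from:
>>> vals = [1, 2, 3] # a plain python list
>>> for v in vals: # v is just a name bound to each element
...     v = v * 10 # rebinds v to a new object; vals is untouched
...
>>> vals # original list unchanged
[1, 2, 3]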
Consider:
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> lin = torch.nn.Linear (2, 3) # a pytorch "model"
>>>
>>> lin.weight # its first parameter
Parameter containing:
tensor([[-0.1004, 0.3112],
[ 0.6338, -0.0288],
[ 0.0585, 0.6938]], requires_grad=True)
>>> lin.bias # its second parameter
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478], requires_grad=True)
>>>
>>> m1 = torch.randn (3, 2, requires_grad = True) # first mask
>>> m2 = torch.randn (3, requires_grad = True) # second mask
>>> masks = [m1, m2] # packaged as a list
>>>
>>> x = torch.randn (2) # some input to the model
>>>
>>> lin (x) # output before attempted modification
tensor([-0.4202, -0.0696, -0.9600], grad_fn=<AddBackward0>)
>>>
>>> for idx, param in enumerate(lin.parameters()): # failed attempt at modification
...     param = param * masks[idx] # creates and discards new tensors
...
>>> lin.weight # model parameters unchanged
Parameter containing:
tensor([[-0.1004, 0.3112],
[ 0.6338, -0.0288],
[ 0.0585, 0.6938]], requires_grad=True)
>>> lin.bias # model parameters unchanged
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478], requires_grad=True)
>>>
>>> lin (x) # output unchanged after attempted modification
tensor([-0.4202, -0.0696, -0.9600], grad_fn=<AddBackward0>)
>>>
>>> loss = lin (x).sum() # some dummy loss
>>> loss.backward() # backpropagate
>>>
>>> m1.grad # no grads for masks
>>> m2.grad # no grads for masks
>>>
>>> lin.weight.grad # but model parameters do have grads
tensor([[ 0.4833, -0.7791],
[ 0.4833, -0.7791],
[ 0.4833, -0.7791]])
>>> lin.bias.grad # but model parameters do have grads
tensor([1., 1., 1.])
Here’s how to implement this so that the masks do contribute to the model
and so that you can compute gradients with respect to the masks.
Normally, pytorch will not let you modify model parameters inplace (because
they are leaf tensors with requires_grad = True), but if you freeze the
parameters (as you might be doing for pretrained weights) by setting their
requires_grad = False, you can. Then you can multiply by the masks inplace,
param.mul_ (masks[idx]), and things will work:
>>> lin.weight.grad = None # clear model grads
>>> lin.bias.grad = None # clear model grads
>>>
>>> lin.weight.requires_grad = False # freeze model parameters
>>> lin.bias.requires_grad = False # by setting requires_grad = False
>>>
>>> lin.weight # requires_grad is False (so doesn't display)
Parameter containing:
tensor([[-0.1004, 0.3112],
[ 0.6338, -0.0288],
[ 0.0585, 0.6938]])
>>> lin.bias # requires_grad is False (so doesn't display)
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478])
>>>
>>> for idx, param in enumerate(lin.parameters()): # successful parameter modification
...     _ = param.mul_ (masks[idx]) # modify actual parameters inplace
...
>>> lin.weight # model parameters modified
Parameter containing:
tensor([[ 0.0377, 0.5826],
[-0.1272, 0.0178],
[-0.0097, -0.5214]], grad_fn=<MulBackward0>)
>>> lin.bias # and have grad_fn because of modification
Parameter containing:
tensor([-0.1900, -0.1723, 0.0791], grad_fn=<MulBackward0>)
>>>
>>> lin (x) # changed output reflects changed parameters
tensor([-0.6257, -0.2476, 0.4806], grad_fn=<AddBackward0>)
>>>
>>> loss = lin (x).sum() # some dummy loss
>>> loss.backward() # backpropagate
>>>
>>> m1.grad # now gradients flow back to masks
tensor([[-0.0485, -0.2425],
[ 0.3063, 0.0224],
[ 0.0283, -0.5405]])
>>> m2.grad # now gradients flow back to masks
tensor([-0.1293, -0.3984, -0.4478])
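Putting the two pieces together, here is a minimal sketch of how you might
package this in your own code. (The helper name apply_masks_, the
Linear (4, 2) model, and the random masks are just illustrations, not your
actual code.)

import torch

def apply_masks_ (model, masks):
    # hypothetical helper: freeze the (e.g., pretrained) parameters so that
    # inplace modification is allowed, then multiply them by the masks inplace
    for param, mask in zip (model.parameters(), masks):
        param.requires_grad = False
        param.mul_ (mask)

model = torch.nn.Linear (4, 2) # illustrative model
masks = [torch.randn_like (p, requires_grad = True) for p in model.parameters()]

apply_masks_ (model, masks) # modify parameters inplace

loss = model (torch.randn (4)).sum() # some dummy loss
loss.backward() # backpropagate; gradients flow back to the masks
print (masks[0].grad.shape) # torch.Size([2, 4])

Note that apply_masks_ overwrites the stored parameter values, so if you need
to re-apply fresh masks every iteration you would want to keep a copy of the
original weights.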
Best.
K. Frank