Hi Soroush!
Just to be clear about why masks is not actually used in the model:

As mentioned above, this code:

for idx, param in enumerate(model.parameters()):
    param = param * masks[idx]

creates new tensors but does not modify the model.parameters() tensors.
In this context, param = param * masks[idx] is effectively equivalent to
new_tensor_that_will_be_ignored_and_discarded = param * masks[idx].
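This is just ordinary python name binding, nothing specific to pytorch; the
same thing happens with, say, a plain python list:

>>> a = [1, 2, 3]
>>> for x in a:
...     x = x * 10   # rebinds the local name x; the list a is untouched
...
>>> a
[1, 2, 3]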
Consider:
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> lin = torch.nn.Linear (2, 3) # a pytorch "model"
>>>
>>> lin.weight # its first parameter
Parameter containing:
tensor([[-0.1004, 0.3112],
[ 0.6338, -0.0288],
[ 0.0585, 0.6938]], requires_grad=True)
>>> lin.bias # its second parameter
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478], requires_grad=True)
>>>
>>> m1 = torch.randn (3, 2, requires_grad = True) # first mask
>>> m2 = torch.randn (3, requires_grad = True) # second mask
>>> masks = [m1, m2] # packaged as a list
>>>
>>> x = torch.randn (2) # some input to the model
>>>
>>> lin (x) # output before attempted modification
tensor([-0.4202, -0.0696, -0.9600], grad_fn=<AddBackward0>)
>>>
>>> for idx, param in enumerate(lin.parameters()): # failed attempt at modification
... param = param * masks[idx] # creates and discards new tensors
...
>>> lin.weight # model parameters unchanged
Parameter containing:
tensor([[-0.1004, 0.3112],
[ 0.6338, -0.0288],
[ 0.0585, 0.6938]], requires_grad=True)
>>> lin.bias # model parameters unchanged
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478], requires_grad=True)
>>>
>>> lin (x) # output unchanged after attempted modification
tensor([-0.4202, -0.0696, -0.9600], grad_fn=<AddBackward0>)
>>>
>>> loss = lin (x).sum() # some dummy loss
>>> loss.backward() # backpropagate
>>>
>>> m1.grad # no grads for masks
>>> m2.grad # no grads for masks
>>>
>>> lin.weight.grad # but model parameters do have grads
tensor([[ 0.4833, -0.7791],
[ 0.4833, -0.7791],
[ 0.4833, -0.7791]])
>>> lin.bias.grad # but model parameters do have grads
tensor([1., 1., 1.])
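As an aside, simply switching to an in-place multiplication doesn't fix
things at this point, because the parameters are still leaf tensors that
require grad. Trying it in a session like the one above fails with an error
roughly like this (the exact wording may differ across pytorch versions):

>>> lin.weight.mul_ (masks[0])   # in-place op on a leaf that requires grad
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.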
Here’s how to implement this so that masks does contribute to the model and
so that you can compute gradients with respect to masks.

Normally, pytorch will not let you modify model parameters in place, but if
you freeze the parameters (as you might be doing for pretrained weights)
by setting their requires_grad = False, you can. Then you can multiply
by masks in place, param.mul_ (masks[idx]), and things will work:
>>> lin.weight.grad = None # clear model grads
>>> lin.bias.grad = None # clear model grads
>>>
>>> lin.weight.requires_grad = False # freeze model parameters
>>> lin.bias.requires_grad = False # by setting requires_grad = False
>>>
>>> lin.weight # requires_grad is False (so doesn't display)
Parameter containing:
tensor([[-0.1004, 0.3112],
[ 0.6338, -0.0288],
[ 0.0585, 0.6938]])
>>> lin.bias # requires_grad is False (so doesn't display)
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478])
>>>
>>> for idx, param in enumerate(lin.parameters()): # successful parameter modification
... _ = param.mul_ (masks[idx]) # modify actual parameters inplace
...
>>> lin.weight # model parameters modified
Parameter containing:
tensor([[ 0.0377, 0.5826],
[-0.1272, 0.0178],
[-0.0097, -0.5214]], grad_fn=<MulBackward0>)
>>> lin.bias # and have grad_fn because of modification
Parameter containing:
tensor([-0.1900, -0.1723, 0.0791], grad_fn=<MulBackward0>)
>>>
>>> lin (x) # changed output reflects changed parameters
tensor([-0.6257, -0.2476, 0.4806], grad_fn=<AddBackward0>)
>>>
>>> loss = lin (x).sum() # some dummy loss
>>> loss.backward() # backpropagate
>>>
>>> m1.grad # now gradients flow back to masks
tensor([[-0.0485, -0.2425],
[ 0.3063, 0.0224],
[ 0.0283, -0.5405]])
>>> m2.grad # now gradients flow back to masks
tensor([-0.1293, -0.3984, -0.4478])
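If it helps, here is one way you might package the two steps (freeze the
parameters, then multiply them by the masks in place) into a small helper.
This is just a sketch, and the name apply_masks_ is made up for illustration:

import torch

def apply_masks_ (model, masks):
    # freeze the model parameters so that pytorch permits
    # in-place modification of these leaf tensors
    for param in model.parameters():
        param.requires_grad = False
    # multiply each parameter by its mask in place; because the masks
    # require grad, the masked parameters get a grad_fn and gradients
    # will flow back to the masks
    for param, mask in zip (model.parameters(), masks):
        param.mul_ (mask)

lin = torch.nn.Linear (2, 3)
masks = [torch.randn_like (p, requires_grad = True) for p in lin.parameters()]
apply_masks_ (lin, masks)

loss = lin (torch.randn (2)).sum()   # some dummy loss
loss.backward()                      # backpropagate
print (masks[0].grad)                # gradients flow back to the masks
print (masks[1].grad)

Note that mul_() modifies the stored parameters permanently, so if you want
to re-apply fresh masks later you would need to keep a copy of the original
(unmasked) parameters.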
Best.
K. Frank