Optimizing a mask for weights instead of weights themselves

Hi. I’m trying to optimize a mask for each weight of a simple DNN without updating the weights directly. My goal is to find a mask that I can multiply with the network weights to achieve maximum accuracy. The mask has the same size as the weights and holds a coefficient for each weight parameter. The code I implemented is as follows:

learning_rate = 1
num_iterations = 3

weights = list(model.parameters())
masks = [torch.rand_like(layer.data, requires_grad=True) for layer in model.parameters()]

optimizer = torch.optim.Adam(masks, lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()

for _ in range(num_iterations):
    for x, y in self.test_loader:
        optimizer.zero_grad()

        for idx, param in enumerate(model.parameters()):
            param = param * masks[idx]

        x = model(x)
        loss = criterion(x, y)

        loss.backward()

        print("mask", masks[-1].data)
        print("grad", masks[-1].grad)

        optimizer.step()

        with torch.no_grad():
            masks = [mask.clamp_(0.0, 1.0) for mask in masks]

I multiply the masks by the weights and then feed data to the model. Then I backpropagate the loss and, since only the masks require gradients, I expect it to compute the gradients of the masks. But this code gives “None” for the gradients of all masks.
I’ve spent a lot of time trying to resolve it but haven’t been successful yet.

Hi Soroush!

This is more of a pure python issue than one of pytorch in that you are
setting a python name (a “variable”) to refer to a new object, rather than
modifying the object to which that name previously referred.

The short story is that you are not actually modifying your model weights,
so backpropagating through model doesn’t compute the gradients you
expect.

As an aside, why? Why not optimize your model weights directly?

In the loop header for idx, param in enumerate(model.parameters()):, param
is a name in your python script that refers to one of the parameters in your
model.

In the line param = param * masks[idx], the right-hand side of the equals sign
creates a new tensor by multiplying together the model parameter referred to
by param and masks[idx]. The equals sign then sets the name param to refer to
this new tensor. After this line of code, param no longer refers to a
parameter of your model.
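
The same rebinding behavior can be seen without pytorch at all. Here is a
minimal plain-python sketch (the nested list weights is just a hypothetical
stand-in for model parameters, not anything from your code):

```python
# Plain-python analogue: rebinding a loop variable vs. mutating the object it names.
weights = [[1.0, 2.0], [3.0, 4.0]]   # hypothetical stand-in for model parameters

# Rebinding: `row` is pointed at a brand-new list, so `weights` is untouched.
for row in weights:
    row = [2.0 * v for v in row]

unchanged = [list(r) for r in weights]

# Mutation: slice assignment writes into the existing list object.
for row in weights:
    row[:] = [2.0 * v for v in row]

changed = [list(r) for r in weights]

print(unchanged)  # [[1.0, 2.0], [3.0, 4.0]]
print(changed)    # [[2.0, 4.0], [6.0, 8.0]]
```

The first loop corresponds to param = param * masks[idx]: a new object is
created and the name is moved to it, while the container still holds the
original objects.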

Because masks[idx] has requires_grad = True (as presumably do your
model parameters), the new tensor to which param now refers also has
requires_grad = True, but this doesn’t matter, because you never use
(the tensor referred to by) param again.

x and loss depend on your model parameters, but they do not depend
on masks because your model parameters were not updated to depend
on masks. That is, your model does not depend on the new tensor created
by param * masks[idx], so neither do x nor loss.

loss does not depend on masks, so loss.backward() does not compute
gradients with respect to masks.

As outlined above, gradients with respect to masks are not computed, so,
indeed, masks[-1].grad remains None.

Best.

K. Frank

Thank you Frank for your thorough explanation. But to be honest I still don’t know how to fix it :sweat_smile:
To answer your first question

This code is part of a larger codebase that aims to prune the network parameters, so these mask tensors will be used to prune some parameters later. I also forgot to mention that the model weights are pretrained.

Second, you’re correct that the mask is not actually used in the model in the code snippet I provided, so I can’t backpropagate the gradient of the loss w.r.t. it. But I’m following the idea of backpropagation. Please correct me if I’m wrong. In the ordinary case, the forward pass looks like this:

A = ReLU(W . X + b)

but in my case, I’m calculating it like this:

A’ = ReLU((M * W) . X + b)

so I think the mask M does contribute to the model’s output, and I should be able to calculate gradients w.r.t. M. But I don’t know how exactly!

Hi Soroush!

Just to be clear why masks are not used in the model:

As mentioned above, this code:

        for idx, param in enumerate(model.parameters()):
            param = param * masks[idx]

creates new tensors but does not modify the model.parameters() tensors.

In this context, param = param * masks[idx] is effectively equivalent to
new_tensor_that_will_be_ignored_and_discarded = param * masks[idx].

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> lin = torch.nn.Linear (2, 3)                     # a pytorch "model"
>>>
>>> lin.weight                                       # its first parameter
Parameter containing:
tensor([[-0.1004,  0.3112],
        [ 0.6338, -0.0288],
        [ 0.0585,  0.6938]], requires_grad=True)
>>> lin.bias                                         # its second parameter
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478], requires_grad=True)
>>>
>>> m1 = torch.randn (3, 2, requires_grad = True)    # first mask
>>> m2 = torch.randn (3, requires_grad = True)       # second mask
>>> masks = [m1, m2]                                 # packaged as a list
>>>
>>> x = torch.randn (2)                              # some input to the model
>>>
>>> lin (x)                                          # output before attempted modification
tensor([-0.4202, -0.0696, -0.9600], grad_fn=<AddBackward0>)
>>>
>>> for idx, param in enumerate(lin.parameters()):   # failed attempt at modification
...     param = param * masks[idx]                   # creates and discards new tensors
...
>>> lin.weight                                       # model parameters unchanged
Parameter containing:
tensor([[-0.1004,  0.3112],
        [ 0.6338, -0.0288],
        [ 0.0585,  0.6938]], requires_grad=True)
>>> lin.bias                                         # model parameters unchanged
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478], requires_grad=True)
>>>
>>> lin (x)                                          # output unchanged after attempted modification
tensor([-0.4202, -0.0696, -0.9600], grad_fn=<AddBackward0>)
>>>
>>> loss = lin (x).sum()                             # some dummy loss
>>> loss.backward()                                  # backpropagate
>>>
>>> m1.grad                                          # no grads for masks
>>> m2.grad                                          # no grads for masks
>>>
>>> lin.weight.grad                                  # but model parameters do have grads
tensor([[ 0.4833, -0.7791],
        [ 0.4833, -0.7791],
        [ 0.4833, -0.7791]])
>>> lin.bias.grad                                    # but model parameters do have grads
tensor([1., 1., 1.])

Here’s how to implement this so that masks does contribute to model and
so that you can compute gradients with respect to masks.

Normally, pytorch will not let you modify model parameters inplace, but if
you freeze the parameters (as you might be doing for pretrained weights)
by setting their requires_grad = False, you can. Then you can multiply
by masks inplace, param.mul_ (masks[idx]), and things will work:

>>> lin.weight.grad = None                           # clear model grads
>>> lin.bias.grad = None                             # clear model grads
>>>
>>> lin.weight.requires_grad = False                 # freeze model parameters
>>> lin.bias.requires_grad = False                   # by setting requires_grad = False
>>>
>>> lin.weight                                       # requires_grad is False (so doesn't display)
Parameter containing:
tensor([[-0.1004,  0.3112],
        [ 0.6338, -0.0288],
        [ 0.0585,  0.6938]])
>>> lin.bias                                         # requires_grad is False (so doesn't display)
Parameter containing:
tensor([-0.1293, -0.3984, -0.4478])
>>>
>>> for idx, param in enumerate(lin.parameters()):   # successful parameter modification
...     _ = param.mul_ (masks[idx])                  # modify actual parameters inplace
...
>>> lin.weight                                       # model parameters modified
Parameter containing:
tensor([[ 0.0377,  0.5826],
        [-0.1272,  0.0178],
        [-0.0097, -0.5214]], grad_fn=<MulBackward0>)
>>> lin.bias                                         # and have grad_fn because of modification
Parameter containing:
tensor([-0.1900, -0.1723,  0.0791], grad_fn=<MulBackward0>)
>>>
>>> lin (x)                                          # changed output reflects changed parameters
tensor([-0.6257, -0.2476,  0.4806], grad_fn=<AddBackward0>)
>>>
>>> loss = lin (x).sum()                             # some dummy loss
>>> loss.backward()                                  # backpropagate
>>>
>>> m1.grad                                          # now gradients flow back to masks
tensor([[-0.0485, -0.2425],
        [ 0.3063,  0.0224],
        [ 0.0283, -0.5405]])
>>> m2.grad                                          # now gradients flow back to masks
tensor([-0.1293, -0.3984, -0.4478])
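
One caveat: multiplying inplace overwrites the stored (pretrained) weights, so
running it repeatedly inside a training loop would apply the mask again each
time. A possible alternative, sketched here under the assumption of pytorch
2.0 or later, is torch.func.functional_call, which runs the model with
substitute parameter tensors and leaves the stored weights alone. The small
Sequential model, the random data, and the learning rate below are
hypothetical stand-ins, not your actual setup:

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)

# Hypothetical stand-ins for the pretrained model and the data loader.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 3)
)
data = [(torch.randn(16, 4), torch.randint(0, 3, (16,))) for _ in range(5)]

# Freeze the weights; only the masks are trained.
for p in model.parameters():
    p.requires_grad_(False)

params = dict(model.named_parameters())
snapshot = {name: p.detach().clone() for name, p in params.items()}

# One trainable mask per parameter, keyed by the parameter's name.
masks = {name: torch.rand_like(p).requires_grad_() for name, p in params.items()}

optimizer = torch.optim.Adam(masks.values(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()

for x, y in data:
    optimizer.zero_grad()
    # Build masked weights as new tensors; the stored weights are never changed.
    masked = {name: p * masks[name] for name, p in params.items()}
    out = functional_call(model, masked, (x,))   # forward pass with masked weights
    loss = criterion(out, y)
    loss.backward()                              # gradients flow back to the masks
    optimizer.step()
    with torch.no_grad():
        for m in masks.values():
            m.clamp_(0.0, 1.0)                   # keep mask values in [0, 1]
```

Because the masked weights are rebuilt from the frozen weights and the current
masks on every batch, each loss.backward() populates the .grad of every mask,
and the pretrained weights stay available for the later pruning step.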

Best.

K. Frank

Thank you so much @KFrank for your descriptive example. Now I know what’s going on :pray: