Backpropagate with respect to mask

Hello, I am looking for a way to backpropagate with respect to a mask matrix that the weights (say, the weights of a torch.nn.Linear or Conv2d) are multiplied by. The problem is that the weights are instances of the Parameter class, and thus leaf nodes.

Right now I am doing it like this before backpropagating through the mask:

        temp = model[layer_nr].weight.data   # grab the underlying tensor
        del model[layer_nr].weight           # remove the Parameter from the module
        model[layer_nr].weight = temp        # re-attach it as a plain tensor

And later, when I want to go back to backpropagating with respect to the parameters:
model[layer_nr].weight = nn.Parameter(model[layer_nr].weight)

I am looking for a better, simpler way. Is there something in PyTorch to help me do such a thing?

If I understand correctly, sometimes you are trying to backpropagate with respect to a mask (an input?) and sometimes with respect to the parameters, but not both at the same time?

Instead of deleting them, you can just set requires_grad=False when trying to get the mask gradient. Something like:

for param in model.parameters():
    param.requires_grad = False

mask = ...  # mask should require grad
loss = model(input1, mask)
loss.backward()  # get gradient for mask

...

for param in model.parameters():
    param.requires_grad = True
loss = ...
loss.backward()  # get gradient for parameters
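
To make that concrete, here is a minimal runnable sketch. The MaskedLinear module below is made up for illustration; it applies the mask to the weight inside forward (via torch.nn.functional.linear), so the mask stays in the autograd graph:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Module):
    """Hypothetical module: multiplies the weight by a mask inside forward."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, mask):
        # The multiplication happens inside the graph, so gradients can flow
        # back to the mask (and to the weight, if it requires grad).
        return F.linear(x, self.linear.weight * mask, self.linear.bias)

model = MaskedLinear(2, 3)
mask = torch.ones_like(model.linear.weight, requires_grad=True)

for param in model.parameters():
    param.requires_grad = False  # freeze parameters while computing the mask gradient

x = torch.tensor([0.5, 0.3])
loss = model(x, mask).sum()
loss.backward()
print(mask.grad)                 # populated
print(model.linear.weight.grad)  # None, since the parameters were frozen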

Thank you for your answer! Actually, I am trying to replicate a specific pruning technique which requires backpropagating with respect to the mask. I will provide a simple example, not quite the same as what I am implementing, but hopefully it explains my problem and my (maybe strange) solution.

model = torch.nn.Linear(2,3)
mask = torch.tensor([1.,2.], requires_grad=True)

input = torch.tensor([0.5,0.3])
model.weight.data = model.weight * mask  # something like multiplying the weights by my pruning mask
output = model.forward(input)
loss = sum(output)

loss.backward()
print(mask.grad)
# None - the assignment through .data never enters the autograd graph

Please don’t try to make sense of this piece of code; its purpose is only to show what I’d like to get a gradient for. At the moment, to get the desired result, I do:

model = torch.nn.Linear(2,3)
mask = torch.tensor([1.,2.], requires_grad=True)

temp = model.weight.data  # new code: replace the Parameter with a plain tensor
del model.weight
model.weight = temp

input = torch.tensor([0.5,0.3])
model.weight = model.weight * mask  # weight is now a non-leaf tensor, so the graph reaches mask
output = model.forward(input)
loss = sum(output)

loss.backward()
print(mask.grad)
# an actual gradient value

Is there a better way to do this? Can I ‘force’ a Parameter to be a non-leaf node without deleting it? Deleting it ruins my model a bit.

I’ve set up a couple of functions which could possibly help you out.

def revert_mask(maskdict, net, netdict_orig):
    '''
    Reverts certain elements to their original state.
    Values with True in the mask are reverted to their original state.
    Values with False in the mask are forced to 0.
    '''
    for name, param in net.named_parameters():
        mask = maskdict[name]  # boolean mask with the same shape as param
        param.data[mask] = netdict_orig[name][mask]
        param.data[~mask] = 0.0  # force pruned values to stay at 0
    return net

def update_grad(net, maskdict):
    for name, param in net.named_parameters():
        mask = maskdict[name]
        param.grad[~mask] = 0  # zero out gradients of pruned entries
    return net

Given a mask, they’ll zero out weights (and gradients, if you need to - I’ve kept the two steps in separate functions). A rough usage sketch follows below. Hope this helps.
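
For example, here is one possible way to build maskdict (boolean masks) and netdict_orig (a snapshot of the original parameters) and call the two functions. The layer, the random masks, and the dummy loss are just placeholders:

import torch

net = torch.nn.Linear(2, 3)

# Snapshot of the original parameters and a random boolean mask per parameter.
netdict_orig = {name: p.detach().clone() for name, p in net.named_parameters()}
maskdict = {name: torch.rand_like(p) > 0.5 for name, p in net.named_parameters()}

net = revert_mask(maskdict, net, netdict_orig)  # pruned entries are forced to 0

loss = net(torch.tensor([0.5, 0.3])).sum()
loss.backward()
net = update_grad(net, maskdict)  # pruned entries receive no gradient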