Hello, I am looking for a way to backpropagate with respect to a mask matrix that the weights (say, the weights of torch.nn.Linear or Conv2d) are multiplied by. The problem is that the weights are instances of the Parameter class, and therefore leaf nodes.
Right now I do the following before backpropagating through the mask:
temp = model[layer_nr].weight.data
del model[layer_nr].weight
model[layer_nr].weight = temp
And later, when I want to go back to backprop with respect to parameters: model[layer_nr].weight = nn.Parameter(model[layer_nr].weight)
I am looking for a better, simpler way. Is there something in PyTorch to help me do such a thing?
If I understand correctly, sometimes you are trying to backpropagate with respect to a mask (an input?) and sometimes with respect to the parameters, but not both at the same time?
Instead of deleting them, you can just set requires_grad=False when trying to get the mask gradient. Something like:
for param in model.parameters():
    param.requires_grad = False
mask = ... # mask should require grad
loss = model(input1, mask)
loss.backward() # get gradient for mask
...
for param in model.parameters():
    param.requires_grad = True
loss = ...
loss.backward() # get gradient for parameters
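To make the snippet above concrete, here is a minimal runnable sketch. It assumes the model's forward accepts the mask as a second argument; `MaskedLinear` is a hypothetical module written only for this illustration, not something PyTorch provides.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Module):
    """Hypothetical layer: scales its weight by a mask passed to forward."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, mask):
        # The parameter itself stays a leaf; only the product is a non-leaf.
        return F.linear(x, self.linear.weight * mask, self.linear.bias)


model = MaskedLinear(2, 3)
x = torch.tensor([0.5, 0.3])
mask = torch.ones(3, 2, requires_grad=True)  # same shape as the weight

# Phase 1: freeze the parameters and collect the gradient for the mask only.
for param in model.parameters():
    param.requires_grad = False
model(x, mask).sum().backward()
weight_grad_skipped = model.linear.weight.grad is None  # no gradient was stored

# Phase 2: unfreeze the parameters and collect their gradients instead.
for param in model.parameters():
    param.requires_grad = True
model(x, mask.detach()).sum().backward()  # detach so the mask gets nothing new
```

After phase 1 only `mask.grad` is populated; after phase 2 `model.linear.weight.grad` and `model.linear.bias.grad` are populated as well.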
Thank you for your answer! Actually, I am trying to replicate a specific pruning technique that requires backpropagating with respect to the mask. Here is a simple example. It is not quite what I am implementing, but hopefully it explains my problem and my (maybe strange) solution.
model = torch.nn.Linear(2,3)
mask = torch.tensor([1.,2.], requires_grad=True)
input = torch.tensor([0.5,0.3])
model.weight.data = model.weight * mask ##this is something like multiplying weights through my pruning mask
output = model.forward(input)
loss = sum(output)
loss.backward()
print(mask.grad)
#None
Please don’t try to make sense of this piece of code; its only purpose is to show what I’d like to get the gradient with respect to. At the moment, to get the desired result, I do:
model = torch.nn.Linear(2,3)
mask = torch.tensor([1.,2.], requires_grad=True)
temp = model.weight.data # new code
del model.weight
model.weight = temp
input = torch.tensor([0.5,0.3])
model.weight = model.weight * mask
output = model.forward(input)
loss = sum(output)
loss.backward()
print(mask.grad)
#Actual value
Is there a better way to do this? Can I ‘force’ a Parameter to be a non-leaf node without deleting it? Deleting it ruins my model a bit.
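One possible way to get this without deleting anything, sketched here as an assumption rather than as what the pruning technique itself prescribes, is `torch.nn.utils.parametrize`. The original Parameter stays registered as a leaf, while `model.weight` becomes a tensor recomputed from it and the mask. `MaskParametrization` is a hypothetical helper class for this example.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class MaskParametrization(nn.Module):
    """Hypothetical helper: exposes weight * mask as the module's weight."""

    def __init__(self, mask):
        super().__init__()
        self.mask = mask  # plain tensor attribute that requires grad

    def forward(self, weight):
        return weight * self.mask


model = nn.Linear(2, 3)
mask = torch.tensor([1., 2.], requires_grad=True)
inp = torch.tensor([0.5, 0.3])

# model.weight is now computed on access; the Parameter is kept as
# model.parametrizations.weight.original.
parametrize.register_parametrization(model, "weight", MaskParametrization(mask))

loss = model(inp).sum()
loss.backward()

original = model.parametrizations.weight.original
print(mask.grad)       # gradient with respect to the mask
print(original.grad)   # gradient with respect to the underlying parameter
```

Note that the parameter then appears in `model.state_dict()` under `parametrizations.weight.original` instead of `weight`. `parametrize.remove_parametrizations(model, "weight")` undoes the registration.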
I’ve set up a function which could possibly help you out.
def revert_mask(maskdict, net, netdict_orig):
    '''
    Reverts certain elements to their original state.
    Values with TRUE in mask are reverted to original state.
    Values with FALSE in mask are forced to be at 0.
    '''
    for name, param in net.named_parameters():
        mask = maskdict[name]
        param.data[mask] = netdict_orig[name][mask]
        param.data[~mask] = 0.0 # Force values to stay at 0
    return net
def update_grad(net, maskdict):
    for name, param in net.named_parameters():
        mask = maskdict[name]
        param.grad[~mask] = 0 # Zero out grad values.
    return net
Given a mask, it’ll zero out weights (and gradients, if you need to - I’ve kept them in different functions). Hope this helps.
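A short usage sketch of the two functions above, repeated here so the block runs on its own. The random boolean masks, the saved `netdict_orig` copy, and the plain SGD optimizer are assumptions made for illustration only.

```python
import torch
import torch.nn as nn


def revert_mask(maskdict, net, netdict_orig):
    # Restore kept entries from the saved copy and zero the pruned ones.
    for name, param in net.named_parameters():
        mask = maskdict[name]
        param.data[mask] = netdict_orig[name][mask]
        param.data[~mask] = 0.0
    return net


def update_grad(net, maskdict):
    # Zero the gradients of pruned entries so the optimizer leaves them alone.
    for name, param in net.named_parameters():
        mask = maskdict[name]
        param.grad[~mask] = 0
    return net


torch.manual_seed(0)
net = nn.Linear(2, 3)

# Assumed inputs: a copy of the original values and one boolean mask per parameter.
netdict_orig = {name: p.detach().clone() for name, p in net.named_parameters()}
maskdict = {name: torch.rand_like(p) > 0.5 for name, p in net.named_parameters()}

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # no momentum or weight decay

revert_mask(maskdict, net, netdict_orig)   # apply the mask to the weights
loss = net(torch.tensor([0.5, 0.3])).sum()
loss.backward()
update_grad(net, maskdict)                 # call between backward() and step()
optimizer.step()
```

With plain SGD the pruned entries stay at zero after the step. An optimizer with momentum or weight decay could still move them, so this sketch does not cover that case.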