TypeError: cannot assign ‘torch.cuda.FloatTensor’ as parameter ‘weight’ (torch.nn.Parameter or None expected)

I’m doing weight clipping, and I’d like to learn an optimal clipping threshold:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.w_max = nn.Parameter(torch.Tensor([0]), requires_grad=True)

        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(64 * 5 * 5, 10)

    def forward(self, input):
        conv1 = self.conv1(input)
        pool1 = self.pool(conv1)
        relu1 = self.relu(pool1)
        conv2 = self.conv2(relu1)
        pool2 = self.pool(conv2)
        relu2 = self.relu(pool2)
        relu2 = relu2.view(relu2.size(0), -1)
        return self.linear(relu2)

model = Net()
# kaiming_normal_ expects a weight tensor, not the parameters() method itself
for p in model.parameters():
    if p.dim() > 1:
        torch.nn.init.kaiming_normal_(p)
nn.init.constant_(model.w_max, 0.1)  # constant() is deprecated in favor of constant_()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for epoch in range(100):
    for i in range(1000):
        output = model(input)
        loss = nn.CrossEntropyLoss()(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model.conv1.weight[model.conv1.weight >= model.w_max] = model.w_max  # clip weights above the learned threshold (this is the line that fails)

I want the gradient of w_max to be the sum of gradients of all weights above w_max

This works fine for activation clipping thresholds, but for weights I get

TypeError: cannot assign ‘torch.cuda.FloatTensor’ as parameter ‘weight’ (torch.nn.Parameter or None expected)

How do I do this?


Hi,

You cannot set the content of parameters with a Tensor directly.
Maybe use masked_scatter_ or masked_fill_.
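
For example, masked_fill_ replaces the entries selected by a mask with a given value. A quick sketch on a plain tensor (w and t here are just made-up placeholders):

import torch

w = torch.randn(3, 3)
t = 0.1
w.masked_fill_(w > t, t)  # every entry above t is set to t, in place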

I put

self.conv1.weight.masked_scatter_(self.conv1.weight > self.w_max, self.w_max)

as the first line in forward(), but I’m getting:

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

If this is how you update your weights and you don't want gradients to flow back through this op (which I expect is what you want), you should wrap it in a with torch.no_grad(): block. This way the engine knows that you're not trying to make a differentiable op and are just changing the values of the weights:

# More code
optimizer.step()
with torch.no_grad():
  # inside no_grad() the in-place update is not recorded by autograd
  self.conv1.weight.masked_scatter_(self.conv1.weight > self.w_max, self.w_max)

No, I do want to learn both weights and the threshold. Currently I have to do it manually:

loss = loss + 0.001 * model.w_max ** 2
optimizer.zero_grad()
loss.backward()
w_max_grad = torch.sum(model.conv1.weight.grad[model.conv1.weight >= model.w_max])  # sum of gradients of weights above the threshold
model.w_max.grad.data += w_max_grad
optimizer.step()

Note that I impose an L2 penalty on the threshold's growth, so I have to add the clipped weights' gradients to the threshold's L2-loss gradient before I update its value.

This works; however, when I do the same thing for activation clipping, I don't need to do anything manually: the gradients are accumulated correctly in the backward pass.
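
For reference, here is a minimal sketch of why the activation case needs no manual work, as I understand it: because the clip is an op in the forward graph, autograd gives the threshold the sum of the upstream gradients at every clipped position (a_max and x below are made up for the example):

import torch
import torch.nn as nn

a_max = nn.Parameter(torch.tensor(1.0))  # learnable clipping threshold
x = torch.randn(4, 8)

y = torch.min(x, a_max)  # elementwise clip: y == a_max wherever x >= a_max
y.sum().backward()

# All upstream gradients are 1 here, so a_max.grad equals the number of
# clipped positions, i.e. the sum of the gradients over the clipped entries.
print(a_max.grad, (x > a_max).sum())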

Just as a note, you should not do:

model.w_max.grad.data += w_max_grad

But

with torch.no_grad():
    model.w_max.grad += w_max1_grad

The first one can hide very subtle bugs: the code will run, but the computed gradients can be wrong.


Thank you! This made it more stable.

On a related note, is this a good way to do gradient clipping: w_max.grad.data.clamp_(-1, 1), right before I manually update the gradient?

Also, is this the best way to zero the gradient: model.w_max.grad.data[:] = 0 ?

You can do the clipping within the torch.no_grad() block and drop the .data as well.
To zero out gradients, you can call model.zero_grad(), which will zero the .grad field of every Parameter in the net.
It does something like:

with torch.no_grad():
  for p in model.parameters():
    p.grad.zero_()
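
So for the threshold gradient, the clipping part could look like this (a sketch reusing model.w_max from your code above):

with torch.no_grad():
    model.w_max.grad.clamp_(-1, 1)  # clip the gradient in place, without .data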