Want to Optimize Only Some Variables but Getting 'can't optimize a non-leaf Tensor' Error

Hi all,

I want to optimize only some variables: for example, create a tensor and then pass only part of it to the optimizer. However, I am getting a "can't optimize a non-leaf Tensor" error. Is there a way to do this?

I have attached a minimal example below. In this case, I want to optimize the source vector so that it matches the target, but I want only the middle 6 elements (parameters) of the source vector to be optimized, i.e. to change, and I don't want to use a masked loss to accomplish this.

import torch
import torch.optim as optim
import torch.nn as nn


target = torch.randn(10)
print('target: ', target)

# I want to optimize the source vector such that it matches the target
# torch.Tensor(10) would return uninitialized memory, so start from zeros instead
source = torch.zeros(10).requires_grad_(True)
print('source: ', source)

# making sure source requires grad, while target doesn't
print('target req grad: ', target.requires_grad)
print('source req grad: ', source.requires_grad)

# passing only some parameters (middle 6, in this case) to the optimizer
optimizer = optim.SGD([source[2:-2]], lr=0.001)  # raises ValueError: can't optimize a non-leaf Tensor

criterion = nn.L1Loss()

starting_diff = criterion(target, source)
print('starting diff: ', starting_diff)
# actual optimization loop
while criterion(target, source) > 0.01:
    optimizer.zero_grad()
    loss = criterion(target, source)
    print('loss: ', loss)
    loss.backward()
    optimizer.step()

No, this won't be possible, as slicing a tensor that requires gradients creates a non-leaf tensor (the slice is the output of an operation recorded by autograd). You could create the sliced tensor with requires_grad=True, pass it to the optimizer, and concatenate or stack it with the other tensor before using it in the actual operation. Alternatively, you could try to zero out the gradients of the parts of the tensor which should not be updated.
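A minimal sketch of the concatenation approach applied to the example from the question. Assumptions: the source vector starts at zeros (instead of the uninitialized torch.Tensor(10)), and the names left, middle, and right, the learning rate of 0.1, and the fixed number of steps are illustrative choices, not part of the original answer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

target = torch.randn(10)
init = torch.zeros(10)  # assumption: zero initialization

# fixed outer parts: plain tensors that are never passed to the optimizer
left = init[:2].clone()
right = init[-2:].clone()
# trainable middle part: a new leaf tensor created directly with requires_grad=True
middle = init[2:-2].clone().requires_grad_(True)

optimizer = torch.optim.SGD([middle], lr=0.1)  # only the leaf `middle` is optimized
criterion = nn.L1Loss()

for step in range(500):
    optimizer.zero_grad()
    # rebuild the full vector in every iteration so the loss sees the current values
    source = torch.cat((left, middle, right))
    loss = criterion(source, target)
    loss.backward()
    optimizer.step()

source = torch.cat((left, middle.detach(), right))
print(source)
```

Only middle receives updates; the two elements at each end keep their initial values because left and right never require gradients.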

Hi @ptrblck, thank you very much! The concatenation approach works perfectly! I also want to give the "zeroing out the gradients" approach a try, but I'm not sure how to zero out the gradients of individual parameters or parts of a tensor. Could you please explain?

You could directly assign zeros to the .grad slice, e.g. via lin.weight.grad[2:4, 2:4] = 0., but you would have to be careful with this approach. If you are using an optimizer with running stats (such as Adam) and these parameters were already updated before (i.e. the running stats contain valid values for the gradient slice), then they would still be updated even if the gradient was set to zero. The same applies to SGD with momentum, and weight decay can change these parameters even with a zero gradient.
I would thus claim that your current approach is less error-prone.
Here is a small example:

import torch
import torch.nn as nn

# setup
lin = nn.Linear(10, 10, bias=False)
optimizer = torch.optim.Adam(lin.parameters(), lr=1.)
x = torch.randn(1, 10)

# zero gradients of parameters which were never updated
out = lin(x)
out.mean().backward()
lin.weight.grad[2:4, 2:4] = 0.

print(lin.weight[2:4, 2:4])
optimizer.step()
print(lin.weight[2:4, 2:4]) # equal
optimizer.zero_grad()

# full update
out = lin(x)
out.mean().backward()
print(lin.weight[2:4, 2:4])
optimizer.step()
print(lin.weight[2:4, 2:4]) # updated
optimizer.zero_grad()

# zeroing out the gradients is not sufficient anymore
out = lin(x)
out.mean().backward()
lin.weight.grad[2:4, 2:4] = 0.

print(lin.weight[2:4, 2:4])
optimizer.step()
print(lin.weight[2:4, 2:4]) # updated !!!
optimizer.zero_grad()
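Applied to the question's example, the gradient-zeroing approach could look like the minimal sketch below. It assumes plain torch.optim.SGD with its defaults (no momentum, no weight decay), where each update is simply -lr * grad, so zeroed gradient entries leave those elements untouched; with Adam, momentum, or weight decay the caveat above applies. The zero initialization, learning rate, and step count are illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

target = torch.randn(10)
# the full vector is the leaf tensor passed to the optimizer
source = torch.zeros(10, requires_grad=True)

# plain SGD: no momentum and no weight decay, so a zero gradient means no update
optimizer = torch.optim.SGD([source], lr=0.1)
criterion = nn.L1Loss()

for step in range(500):
    optimizer.zero_grad()
    loss = criterion(source, target)
    loss.backward()
    # zero the gradients of the 2 elements at each end so only the middle 6 change
    source.grad[:2] = 0.0
    source.grad[-2:] = 0.0
    optimizer.step()

print(source)
```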

Just wondering: the error goes away if I do the slicing inside a with torch.no_grad(): block. Does this, however, disable learning for these params?

My use case is that I want to apply a different learning rate to some parameters of a layer (Transformer token embeddings), so just setting the grad to 0 doesn't cut it.

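That's expected, as you are creating a new leaf tensor, which the optimizer accepts.

However, this sliced tensor is detached from the model's parameter in the autograd graph: it was created under no_grad, so it does not require gradients, and it is never used in the forward pass, so it never receives a gradient. Optimizing it will thus not change anything in the model.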

Here is a small example:

import torch
import torch.nn as nn

lin = nn.Linear(10, 10, bias=False)
param = lin.weight

optimizer = torch.optim.Adam([param], lr=1.) # works

p = param[:2, :2]
try:
    optimizer = torch.optim.Adam([p], lr=1.)  # fails
except ValueError as e:
    print(e)  # can't optimize a non-leaf Tensor

with torch.no_grad():
    p = param[:2, :2]
optimizer = torch.optim.Adam([p], lr=1.) # works, but uses new leaf tensor

# check if lin.parameters are updated
out = lin(torch.randn(1, 10))
out.mean().backward()

p0 = lin.weight.clone()

# entire parameter has valid gradients as expected
for name, param in lin.named_parameters():
    print(name, param.grad)

# sliced parameter was never used, thus no gradients
print(p.grad)
# None 

# does not update model
optimizer.step() 

p1 = lin.weight.clone()

print(p1 - p0)
# all zeros - no update

You might need to create separate parameters for the different slices, combine them in the forward pass using e.g. torch.cat or torch.stack, and optimize each slice with its own learning rate via separate parameter groups.
Here is another small example:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.p0 = nn.Parameter(torch.randn(10, 5))
        self.p1 = nn.Parameter(torch.randn(10, 5))

    def forward(self, x):
        p = torch.cat((self.p0, self.p1), dim=1)
        out = F.linear(x, p)
        return out
    
lin = MyLinear()

optimizer = torch.optim.Adam([
    {'params': [lin.p0], 'lr': 1.},
    {'params': [lin.p1], 'lr': 1e-3},
])

x = torch.randn(1, 10)
out = lin(x)
out.mean().backward()

for name, param in lin.named_parameters():
    print(name, param.grad)

p00 = lin.p0.clone()
p10 = lin.p1.clone()

optimizer.step()

p01 = lin.p0.clone()
p11 = lin.p1.clone()

print(p01 - p00)
# tensor([[ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000],
#         [ 1.0000, -1.0000, -1.0000,  1.0000,  1.0000]], grad_fn=<SubBackward0>)

print(p11 - p10)
# tensor([[0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010],
#         [0.0010, 0.0010, 0.0010, 0.0010, 0.0010]], grad_fn=<SubBackward0>)
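For the token-embedding use case, the same idea can be sketched with F.embedding. SplitEmbedding, num_old, and num_new are hypothetical names; the sketch assumes the rows that need a different learning rate form one contiguous block (e.g. newly added tokens at the end of the vocabulary), and the learning rates are placeholders. Scaling the gradients of some rows would not emulate a smaller learning rate with Adam, since Adam's update is approximately invariant to the gradient scale; separate parameter groups avoid that problem.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitEmbedding(nn.Module):
    """Hypothetical helper: one embedding table stored as two parameters.
    Rows [0, num_old) live in `old`, rows [num_old, num_old + num_new) in `new`."""

    def __init__(self, num_old, num_new, dim):
        super().__init__()
        self.old = nn.Parameter(torch.randn(num_old, dim))
        self.new = nn.Parameter(torch.randn(num_new, dim))

    def forward(self, ids):
        # rebuild the full table in every forward pass so both parts get gradients
        weight = torch.cat((self.old, self.new), dim=0)
        return F.embedding(ids, weight)


emb = SplitEmbedding(num_old=8, num_new=2, dim=4)

optimizer = torch.optim.Adam([
    {'params': [emb.old], 'lr': 1e-5},  # placeholder: small lr for existing rows
    {'params': [emb.new], 'lr': 1e-2},  # placeholder: larger lr for new rows
])

ids = torch.tensor([[0, 3, 8, 9]])  # ids 8 and 9 index the rows stored in `new`
out = emb(ids)
out.sum().backward()

old0 = emb.old.detach().clone()
new0 = emb.new.detach().clone()
optimizer.step()

print((emb.old.detach() - old0).abs().max())  # roughly 1e-5
print((emb.new.detach() - new0).abs().max())  # roughly 1e-2
```

With a pretrained nn.Embedding, the two parameters could be initialized from slices of its weight, e.g. nn.Parameter(pretrained.weight[:num_old].detach().clone()).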

Very clear explanation, thank you!
