Getting an error while optimizing for a parameter: can't optimize a non-leaf Tensor

I am trying to optimize over the value z, where netG is a trained GAN model.

Here is the code. I am getting an error. I want to minimize the loss between the generated images x and netG(z), something like this:

maxEpochs = 100
z = nn.Parameter(torch.rand(64,100,1,1), requires_grad=True).to(device)
opt = torch.optim.Adam([z]).to(device)

for e in range(maxEpochs):
    netG.eval()
    g_z = netG(z)
    generatedImages = fakeimages.to(device)
    loss = nn.MSELoss()(generatedImages, g_z)
    # calculate gradient
    opt.zero_grad()
    loss.backward()
    opt.step()

    losses['rec'].append(loss.item())
    if e % 100 == 0:
        print('[%d] Loss: %0.5f' % (e, loss.item()))

I am getting the error:

ValueError("can't optimize a non-leaf Tensor")

Can anyone help me resolve it?

Hi Iram!

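Let’s start with the first three lines of the code you posted:

maxEpochs = 100
z = nn.Parameter(torch.rand(64,100,1,1), requires_grad=True).to(device)
opt = torch.optim.Adam([z]).to(device)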

Let me assume that device refers to a gpu.

Your second line of code creates a tensor on the cpu and then
“moves” it to the gpu by creating a gpu copy of the original cpu
tensor. Although the original cpu tensor is a “leaf” tensor, autograd
records the copy-to-gpu operation as a “computation”, so the
gpu tensor is not a leaf tensor. An Optimizer expects leaf tensors,
hence the error you report:

ValueError("can't optimize a non-leaf Tensor")

Consider:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> x = torch.nn.Parameter(torch.rand(64,100,1,1), requires_grad=True)
>>> x.is_cuda
False
>>> x.is_leaf
True
>>> y = torch.nn.Parameter(torch.rand(64,100,1,1, device = 'cuda'), requires_grad=True)
>>> y.is_cuda
True
>>> y.is_leaf
True
>>> z = torch.nn.Parameter(torch.rand(64,100,1,1), requires_grad=True).to ('cuda')
>>> z.is_cuda
True
>>> z.is_leaf
False
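The same effect can be reproduced without a gpu. In the sketch below (an illustration, not the original poster's code), a dtype conversion with .to(torch.float64) stands in for the cpu-to-gpu copy: like that copy, it creates a new tensor that autograd records, so the result is not a leaf and Adam rejects it. The last lines show a second possible fix: do the conversion first and wrap the result in a Parameter afterwards.

```python
import torch

# A Parameter constructed directly is a leaf: autograd did not compute it.
w = torch.nn.Parameter(torch.rand(64, 100, 1, 1))
print(w.is_leaf, w.grad_fn)  # True None

# An op applied after construction (here a dtype change, standing in for the
# cpu-to-gpu copy) is recorded by autograd, so its result is not a leaf.
w2 = torch.nn.Parameter(torch.rand(64, 100, 1, 1)).to(torch.float64)
print(w2.is_leaf, w2.grad_fn is not None)  # False True

# torch.optim refuses the non-leaf tensor.
msg = None
try:
    torch.optim.Adam([w2])
except ValueError as err:
    msg = str(err)
print(msg)  # can't optimize a non-leaf Tensor

# Converting / moving first and wrapping afterwards keeps the Parameter a leaf.
w3 = torch.nn.Parameter(torch.rand(64, 100, 1, 1).to(torch.float64))
print(w3.is_leaf)  # True
```

With a gpu, the analogous leaf-preserving pattern would be torch.nn.Parameter(torch.rand(64,100,1,1).to(device)).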

As an aside, your third line of code, as posted, is also broken and
will throw an error even if you construct your Adam optimizer
with a leaf tensor, because a pytorch Optimizer doesn’t have a
.to() method:

>>> opt = torch.optim.Adam([x]).to ('cuda')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'Adam' object has no attribute 'to'

Construct the Parameter you wish to optimize directly on the gpu, as
is done with y in the example I gave above, and then use it to construct
the Optimizer (but don’t attempt to “move” the Optimizer):

>>> opt = torch.optim.Adam([y])
>>> opt.param_groups[0]['params'][0].shape
torch.Size([64, 100, 1, 1])
>>> opt.param_groups[0]['params'][0].is_cuda
True
>>> opt.param_groups[0]['params'][0].is_leaf
True
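Putting these fixes together, a corrected version of the loop from the question might look like the sketch below. It picks the device at runtime, so it runs with or without a gpu. Because the trained netG and the fakeimages batch are not shown in the question, a tiny untrained generator and a random target batch are used as hypothetical stand-ins; the learning rate of 0.01 is likewise an assumption (Adam's default is 0.001).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-ins for the trained netG and fakeimages from the question:
# a tiny untrained generator mapping (N, 100, 1, 1) -> (N, 3, 4, 4).
netG = nn.Sequential(
    nn.ConvTranspose2d(100, 3, kernel_size=4, stride=1, padding=0),
    nn.Tanh(),
).to(device)
fakeimages = torch.rand(64, 3, 4, 4, device=device) * 2 - 1

# Freeze the generator: only z should be optimized.
netG.eval()
for p in netG.parameters():
    p.requires_grad_(False)

# Create z directly on the target device so that it stays a leaf tensor.
z = nn.Parameter(torch.rand(64, 100, 1, 1, device=device))
opt = torch.optim.Adam([z], lr=0.01)  # no .to(): optimizers have no such method
criterion = nn.MSELoss()

losses = {"rec": []}
maxEpochs = 100
for e in range(maxEpochs):
    g_z = netG(z)
    loss = criterion(g_z, fakeimages)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses["rec"].append(loss.item())  # store a plain float, not a tensor
    if e % 10 == 0:
        print("[%d] Loss: %0.5f" % (e, loss.item()))
```

Storing loss.item() rather than loss keeps the losses list as plain floats instead of tensors that still carry autograd history.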

Best.

K. Frank


Yes, thank you so much for your reply. I am still learning this, so thanks again for your great help. Can you please answer one more question?

Can you explain what param_groups[0]['params'][0] is?

Hi Iram!

Looking at an object’s __dict__ attribute is often a good way to
probe all of its attributes. So start with:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> opt = torch.optim.Adam ([torch.randn (3, requires_grad = True)])
>>> opt.__dict__
{'defaults': {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, '_zero_grad_profile_name': 'Optimizer.zero_grad#Adam.zero_grad', 'state': defaultdict(<class 'dict'>, {}), 'param_groups': [{'params': [tensor([0.2824, 0.7790, 0.3820], requires_grad=True)], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}]}

and drill down deeper from there.
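To answer the question directly: param_groups is a list of dicts, one per parameter group. Each dict's 'params' entry is the list of tensors in that group, stored alongside that group's hyperparameters (lr, betas, and so on). So param_groups[0]['params'][0] is the first tensor of the first group; in the earlier example, that is y itself. The sketch below, with illustrative tensors a, b, and c, builds an optimizer with two groups to make the indexing visible.

```python
import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(2, 2, requires_grad=True)
c = torch.randn(4, requires_grad=True)

# Two parameter groups: each dict holds its own "params" list plus any
# per-group hyperparameters; missing ones are filled in from the defaults.
opt = torch.optim.Adam(
    [
        {"params": [a, b]},          # group 0: uses the default lr
        {"params": [c], "lr": 0.1},  # group 1: overrides lr
    ],
    lr=0.001,
)

print(len(opt.param_groups))                  # 2
print(opt.param_groups[0]["params"][0] is a)  # True: first tensor, first group
print(opt.param_groups[0]["params"][1] is b)  # True
print(opt.param_groups[1]["params"][0] is c)  # True
print(opt.param_groups[0]["lr"], opt.param_groups[1]["lr"])  # 0.001 0.1
```

Passing a plain list of tensors, as in the original code, simply creates a single group, which is why index 0 is used there.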

Best.

K. Frank