ValueError: can't optimize a non-leaf Tensor?

Dear all,
I know that when

  x_cuda = x_cpu.to(device)

it will trigger the error:

ValueError: can't optimize a non-leaf Tensor

when you use optimizer = optim.Adam([x_cuda]). The right way may be optimizer = optim.Adam([x_cpu]). That is to say, we need to keep references to both x_cpu and x_cuda.

In most cases, however, our program will only keep a reference to the CUDA version of the tensor, such as:

		self.vars = [
			# [28*28, 512]
			torch.ones(512, 2 * 28 * 28, requires_grad=True).to(device),
			torch.zeros(512, requires_grad=True).to(device),
			# [512, 256]
			torch.ones(256, 512, requires_grad=True).to(device),
			torch.zeros(256, requires_grad=True).to(device),
			# [256, n]
			torch.ones(n_class, 256, requires_grad=True).to(device),
			torch.zeros(n_class, requires_grad=True).to(device)
		]

So I wonder: how can we pass the parameters to the optimizer when we don't want to keep a reference to the CPU version of the tensor?

Hi,

A leaf Variable is a variable that is at the beginning of the graph. That means that no operation tracked by the autograd engine created it.
This is what you want when you optimize neural networks, as these are usually your weights or inputs.

So to be able to give weights to the optimizer, they should satisfy the definition of a leaf variable above.

a = torch.rand(10, requires_grad=True) # a is a leaf variable
a = torch.rand(10, requires_grad=True).double() # a is NOT a leaf variable as it was created by the operation that casts a float tensor into a double tensor
a = torch.rand(10).requires_grad_().double() # equivalent to the formulation just above: not a leaf variable
a = torch.rand(10).double() # a does not require gradients and has no operation creating it (tracked by the autograd engine)
a = torch.rand(10).double().requires_grad_() # a requires gradients and has no operation creating it: it's a leaf variable and can be given to an optimizer
a = torch.rand(10, requires_grad=True, device="cuda") # a requires grad, has no operation creating it: it's a leaf variable as well and can be given to an optimizer

So in your case, you want to use the last line.
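
Applied to the self.vars list from the question, that amounts to something like the sketch below (device and n_class as defined in the surrounding code); each tensor is then a leaf that lives on the GPU and can be passed straight to the optimizer:

    self.vars = [
        # [28*28, 512]
        torch.ones(512, 2 * 28 * 28, requires_grad=True, device=device),
        torch.zeros(512, requires_grad=True, device=device),
        # [512, 256]
        torch.ones(256, 512, requires_grad=True, device=device),
        torch.zeros(256, requires_grad=True, device=device),
        # [256, n]
        torch.ones(n_class, 256, requires_grad=True, device=device),
        torch.zeros(n_class, requires_grad=True, device=device)
    ]
    optimizer = optim.Adam(self.vars, lr=1e-3)  # lr chosen arbitrarily for this sketch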

@albanD Thanks, you solved my problem with a different approach.
However, if I must write a tensor as a = torch.tensor(..).to(device), as in the following linear regression example:

def main2():

	device = torch.device('cuda')
	lr = 5e-2

	w_ = torch.tensor(1., requires_grad=True)
	b_ = torch.tensor(0.1, requires_grad=True)
	w = w_.to(device)
	b = b_.to(device)


	criteon = nn.MSELoss()
	optimizer = optim.Adam([w_, b_], lr=lr)

	for i in range(500):
		x = torch.rand(1).to(device)[0]


		pred = w * x + b
		y = 2 * x + 3
		loss = criteon(pred, y)

		grads = torch.autograd.grad(loss, [w_, b_])
		w_.grad.fill_(grads[0])
		b_.grad.fill_(grads[1])
		optimizer.step()

		print(w_, b_)

It prints the error:

Traceback (most recent call last):
  File "/home/i/ncrs/aaai/fsgan/test2.py", line 94, in <module>
    main2()
  File "/home/i/ncrs/aaai/fsgan/test2.py", line 60, in main2
    w_.grad.fill_(grads[0])
AttributeError: 'NoneType' object has no attribute 'fill_'

What should I do? How can I make the above example code work normally?

The .grad fields are allocated lazily. So if they are None, you can do w_.grad = grads[0] I think.
When they already exist, you don't actually want this copy to be tracked by the autograd engine, so you should do:

grads = torch.autograd.grad(loss, [w_, b_])
with torch.no_grad():
    w_.grad.fill_(grads[0])
    b_.grad.fill_(grads[1])
optimizer.step()
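
Putting the two suggestions together, here is a minimal sketch of a corrected loop (not the exact code above; it assumes a CUDA device is available): the parameters are created as leaves directly on the GPU and the gradients are written into .grad under torch.no_grad():

    import torch
    import torch.nn as nn
    import torch.optim as optim

    device = torch.device('cuda')
    w = torch.tensor(1., requires_grad=True, device=device)   # leaf on the GPU
    b = torch.tensor(0.1, requires_grad=True, device=device)  # leaf on the GPU

    criteon = nn.MSELoss()
    optimizer = optim.Adam([w, b], lr=5e-2)

    for i in range(500):
        x = torch.rand(1, device=device)[0]
        pred = w * x + b
        y = 2 * x + 3
        loss = criteon(pred, y)

        grads = torch.autograd.grad(loss, [w, b])
        with torch.no_grad():
            for p, g in zip([w, b], grads):
                if p.grad is None:             # .grad is allocated lazily
                    p.grad = g.detach().clone()
                else:
                    p.grad.copy_(g)            # copy_ also works for non-scalar parameters
        optimizer.step()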

@albanD Thanks so much.
One more question: I don't understand why it's necessary to use with torch.no_grad().
Could you kindly explain a case that would go wrong if torch.no_grad() is not used?

Thinking about it, w_.grad = grads[0] is not a great idea either :smiley: You should do w_.grad = grads[0].detach().
Otherwise, just as if you were not using torch.no_grad(), you would create a memory leak if grads[0].requires_grad is True.

Basically, a computational graph is kept in memory as long as it could be used for a backward pass.
For example, if you use a net to compute a scalar loss, then as long as that loss exists, the whole computational graph and the intermediate results associated with it will stay in memory.
If you do w_.grad.fill_(grads[0]) with grads that require grad, then w_.grad becomes responsible for the whole graph. And since w_.grad is never deleted, the graph will remain in memory forever.

In your case, grads should not require gradient, as you don't pass the corresponding flags (such as create_graph=True) to autograd.grad.

In general, I like it as it makes the code more readable (you are sure this block is not part of your model) and future-proof: if one day you change something that makes grads require grad, it will still work and not create a memory leak.
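
For example, in the hypothetical case where the grads are computed with create_graph=True (so that they themselves require grad), storing them undetached would keep the whole graph alive:

    grads = torch.autograd.grad(loss, [w_, b_], create_graph=True)  # grads now carry a graph
    w_.grad = grads[0].detach()  # .detach() breaks the link to that graph before storing
    b_.grad = grads[1].detach()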

How do I do this when loading a PIL image?

In my case I am loading a PIL image and then using a pre-configured transform object:

    img = Image.open(pretrained_img_path)
    noise = transform(img).unsqueeze(0)
    noise.requires_grad = True
    print(noise.requires_grad)
    noise = noise.to(device)

If you want the final noise to be the leaf, you want to do noise.requires_grad = True after moving it to the device with noise = noise.to(device).
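
In code, that ordering looks like this (a sketch; transform, device and pretrained_img_path as defined above):

    img = Image.open(pretrained_img_path)
    noise = transform(img).unsqueeze(0).to(device)  # move to the device first
    noise.requires_grad_()                          # then mark the device tensor as requiring grad
    print(noise.is_leaf, noise.requires_grad)       # True True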

Ah. Is there a more "PyTorch" way to do this when loading an image from PIL? Something possible in one line, where I can assign requires_grad = True as I do when creating new tensors. I couldn't find anything in the docs here:

https://pytorch.org/docs/stable/torchvision/transforms.html

I am currently using this transform:

transform = transforms.Compose([transforms.Resize((img_dim, img_dim)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406],
                                                     [0.229, 0.224, 0.225])])

Currently my way of trying to load a PIL image as a trackable tensor seems very sloppy.

@albanD Thank you for the solution in the last line, a = torch.rand(10, requires_grad=True, device="cuda")!
How does this work for a torch.nn.Linear(in_features: int, out_features: int, bias: bool = True) when I want to create the tensor resulting from the linear transformation?

Could you explain what should be altered in this example here to create the linear transformation tensor on CUDA?

>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)

Hi,

For the input, you can do the same thing: input = torch.randn(128, 20, device="cuda"). But for the nn.Module, you will have to first create it and then send it to the GPU:

m = nn.Linear(20, 30)
m.cuda()
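
Putting the two together, a sketch (assuming a CUDA device is available):

    m = nn.Linear(20, 30).cuda()                  # module parameters now live on the GPU
    input = torch.randn(128, 20, device="cuda")   # leaf input created directly on the GPU
    output = m(input)                             # output is computed on the GPU (and is not a leaf)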