I wanted to try out different sets of optimizer hyperparameters for each element of a tensor. I tried the following but I am getting a non-leaf tensor error, possibly because I am indexing the tensor:

/usr/local/lib/python3.6/dist-packages/torch/optim/sgd.py in __init__(self, params, lr, momentum, dampening, weight_decay, nesterov)
62 if nesterov and (momentum <= 0 or dampening != 0):
63 raise ValueError("Nesterov momentum requires a momentum and zero dampening")
---> 64 super(SGD, self).__init__(params, defaults)
65
66 def __setstate__(self, state):

/usr/local/lib/python3.6/dist-packages/torch/optim/optimizer.py in __init__(self, params, defaults)
41
42 for param_group in param_groups:
---> 43 self.add_param_group(param_group)
44
45 def __getstate__(self):

/usr/local/lib/python3.6/dist-packages/torch/optim/optimizer.py in add_param_group(self, param_group)
191 "but one of the params is " + torch.typename(param))
192 if not param.is_leaf:
--> 193 raise ValueError("can't optimize a non-leaf Tensor")
194
195 for name, default in self.defaults.items():

When you do y[0], the Tensor you get is not a leaf Tensor anymore: it is the result of an indexing operation.
Remember, a leaf Tensor is one that you created directly with requires_grad=True (and so is not the result of an operation).
You can only optimize leaf Tensors.
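A minimal sketch of the difference, assuming a 3-element tensor `y` (the names `elem` and `leaf_elem` are just for illustration). It shows that indexing produces a non-leaf, that passing it to SGD triggers the same ValueError as in the traceback, and that `detach().clone().requires_grad_()` gives you a new leaf (which no longer shares memory with `y`):

```python
import torch

y = torch.ones(3, requires_grad=True)  # created directly: a leaf
elem = y[0]                            # result of an indexing op: not a leaf

print(y.is_leaf)     # True
print(elem.is_leaf)  # False, elem has a grad_fn recording the indexing

# Passing the non-leaf element to an optimizer raises the error from the traceback
try:
    torch.optim.SGD([elem], lr=0.1)
    raised = False
except ValueError as err:
    raised = True
    print(err)

# detach() drops the autograd history, clone() copies the data,
# requires_grad_() makes the copy a new leaf that the optimizer accepts
leaf_elem = y[0:1].detach().clone().requires_grad_()
print(leaf_elem.is_leaf)  # True, but updating it will not change y
opt = torch.optim.SGD([leaf_elem], lr=0.1)
```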

If you want to use the built-in optimizers, you will need to create one leaf Tensor for every element (or group of elements) that gets its own hyperparameters, and combine them during the forward pass:

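import torch

# Create one leaf tensor per element, each with requires_grad=True
y0 = torch.ones(1, requires_grad=True)
y1 = torch.ones(1, requires_grad=True)
y2 = torch.ones(1, requires_grad=True)
# One parameter group per leaf, each with its own learning rate
opt2 = torch.optim.SGD([{'params': [y0], 'lr': 0.1},
                        {'params': [y1], 'lr': 1},
                        {'params': [y2], 'lr': 10}])
# During the forward pass, rebuild the full tensor from the leaves:
y = torch.cat([y0, y1, y2], 0)
# The rest of your forward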
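To check that each element really gets its own learning rate, here is a sketch of one full training step. The loss `(y ** 2).sum()` is a toy loss assumed only for illustration; its gradient is 2 for every element at y = 1, so each leaf should move by its own lr * 2. Note that `torch.cat` is called inside `forward()`, so it always uses the current leaf values:

```python
import torch

# One leaf tensor per element, each in its own parameter group
y0 = torch.ones(1, requires_grad=True)
y1 = torch.ones(1, requires_grad=True)
y2 = torch.ones(1, requires_grad=True)
opt = torch.optim.SGD([
    {'params': [y0], 'lr': 0.1},
    {'params': [y1], 'lr': 1},
    {'params': [y2], 'lr': 10},
])

def forward():
    # Rebuild the combined tensor from the current leaf values on every call
    y = torch.cat([y0, y1, y2], 0)
    return (y ** 2).sum()  # toy loss, assumed for illustration

opt.zero_grad()
loss = forward()
loss.backward()  # d(loss)/dy_i = 2 * y_i = 2 for every element here
opt.step()       # each leaf is updated with its own lr: y_i -= lr_i * 2

print(y0.item(), y1.item(), y2.item())  # roughly 0.8, -1.0, -19.0
```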

I tried the snippet pasted below, and I think I will have to run stack or cat after each optimizer.step() call. In the snippet, changes to the leaf tensors don't seem to be propagated to the cat/stack result, while changes to the cat/stack result do seem to be propagated to the leaf:

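import torch

a = torch.randn(1,2)
b = torch.randn(1,2)
c = torch.cat([a,b],0)
print(a)
a.data = a.data+1               # modify the leaf after cat
print(a)
print('*'*50)
print(c)                        # does c reflect the change to a?
print('\n\n')
c.data[0,:] = c.data[0,:] *10   # modify the cat result
print(c,a)                      # does a reflect the change to c?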

As I said, you will need to run cat for each forward pass.

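Changes are never “communicated” between cat or stack and their inputs, in either direction: these are out-of-place operations that copy their inputs into a new Tensor.
The odd behavior you see when you print a at the end is only because you changed a earlier (a.data = a.data+1), not because the change to c reached a.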
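A small sketch of that copy behavior, using zeros and ones instead of random values so the effect is easy to see (the values are arbitrary). Neither in-place change crosses the cat boundary, and running cat again is what picks up the new value of a:

```python
import torch

a = torch.zeros(1, 2)
b = torch.ones(1, 2)
c = torch.cat([a, b], 0)  # c is a new tensor; the values of a and b are copied into it

a += 1       # in-place change to an input...
print(c[0])  # ...is not seen by c: its first row is still zeros

c[1] = 5     # in-place change to the output...
print(b)     # ...is not seen by b: still ones

# To see the updated a, the concatenation has to be run again
c = torch.cat([a, b], 0)
print(c)     # first row now holds the new values of a
```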