Extending parameter tensor dimension

Hi,

I am trying to extend the size of an nn.Parameter without creating a new one. Is that possible?

Here is a simplified code of what I want to do:

import torch

w = torch.nn.Parameter(torch.randn((2, 5)), requires_grad=True)

l = torch.sum(w)
l.backward()

# if "del l" is added here then the code runs 

ext_w = torch.randn((1, 5), requires_grad=True)
w.data = torch.cat((w.data, ext_w), dim=0)  # now w has shape (3, 5)
w.requires_grad = True
w.grad = None

l = torch.sum(w)
l.backward()

If I run this code, I get this error:

RuntimeError: Function SumBackward0 returned an invalid gradient at index 0 - got [3, 5] but expected shape compatible with [2, 5]

My question is: why is this happening? Which object is expecting a shape of [2, 5]?

The weird thing is that if I delete the l variable (del l) before extending the w parameter, the code seems to work.

Thanks a lot in advance.

Hi,

You should never use .data: it is unsafe and, as you can see, it leads to weird errors.
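For in-place edits that should bypass autograd, the supported pattern is to wrap them in torch.no_grad(). A minimal sketch of the general idea (not your resizing case yet):

import torch

w = torch.nn.Parameter(torch.randn(2, 5))

# Unsafe: w.data.mul_(0.5) skips autograd's version tracking, so a stale
# graph can silently compute wrong gradients.
# Safe: do the in-place edit while autograd is disabled.
with torch.no_grad():
    w.mul_(0.5)  # still the same Parameter object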

but without creating a new one. Is it possible?

What do you mean by this? You actually create a brand new Tensor there.

Hi,

thanks for the quick answer.

Well, I do not create a new Parameter object. The longer story is that the parameter lives inside a model that is being trained with an optimizer. I do not want to create a new parameter, because then I would also need to re-create the optimizer each time; if I could just expand the inner data, I could keep using the same optimizer.

(I adapted the optimizer code to deal with “expanding” gradients, which is needed for methods like SGD with momentum.)
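Roughly, the idea is something like this (pad_momentum is just an illustrative helper, not my actual code):

import torch

def pad_momentum(buf, new_shape):
    # Illustrative: zero-pad an SGD momentum buffer along dim 0 so it
    # matches a parameter that grew, e.g. from (2, 5) to (3, 5).
    padded = torch.zeros(new_shape, dtype=buf.dtype, device=buf.device)
    padded[: buf.shape[0]] = buf
    return padded

buf = torch.randn(2, 5)
buf = pad_momentum(buf, (3, 5))  # new rows start with zero momentum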

In that case you can use set_():

import torch

w = torch.nn.Parameter(torch.randn((2, 5)), requires_grad=True)

l = torch.sum(w)
l.backward()

ext_w = torch.randn((1, 5))  # the extra row to append
with torch.no_grad():
    # swap the extended (3, 5) tensor into the same Parameter in place
    w.set_(torch.cat((w, ext_w), dim=0))
w.grad = None  # the old (2, 5) gradient no longer matches

l = torch.sum(w)
l.backward()

And you will need to update the optimizer's state as well, if it has any.
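For example, with torch.optim.SGD and momentum, the state is kept in optimizer.state keyed by the parameter object, so after growing w you can pad the stored momentum_buffer to match. A rough sketch, assuming the default SGD state layout:

import torch

w = torch.nn.Parameter(torch.randn(2, 5))
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)

# one step so the momentum buffer exists
torch.sum(w).backward()
opt.step()

# grow w in place, as above
ext_w = torch.randn(1, 5)
with torch.no_grad():
    w.set_(torch.cat((w, ext_w), dim=0))
w.grad = None  # stale (2, 5) gradient

# pad the matching optimizer state to the new shape;
# w is still the same Parameter object, so opt.state still finds it
buf = opt.state[w]["momentum_buffer"]  # shape (2, 5)
padded = torch.zeros_like(w)           # shape (3, 5), requires_grad=False
padded[: buf.shape[0]] = buf
opt.state[w]["momentum_buffer"] = padded

torch.sum(w).backward()
opt.step()  # now works on the (3, 5) parameter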
