AutoGrad not updating parameters

Hi,
I have the following code structure, where I update the weights through a manually defined parameter vector ‘u’:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self, W):
        super(Net, self).__init__()

        dim = 100
        self.hidden1 = nn.Linear(dim, dim, bias=False)
        self.output = nn.Linear(dim, 10)

        # freeze the Linear weight; its values are supposed to come from u
        self.hidden1.weight.requires_grad = False

        self.u = torch.nn.Parameter(torch.randn(dim, requires_grad=True))
        self.U = torch.ger(self.u, self.u)

        # writing through .data is not recorded by autograd
        self.hidden1.weight.data = (self.U * W)

    def forward(self, x):
        x = self.hidden1(x)
        x = self.output(x)
        return F.log_softmax(x, dim=1)

model = Net(W)
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()

However, the parameter ‘u’ is not getting updated by loss.backward(). Can somebody help me find the mistake I am making here?

Maybe move the requires_grad to Parameter(); it’s currently inside the torch.randn() call. Then check .is_leaf before and after. This note in the docs may help (see also the docs for .is_leaf):

https://pytorch.org/docs/stable/autograd.html#torch.Tensor.requires_grad
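A minimal sketch of the check suggested above (dim=4 is an arbitrary toy size, not from the post). In both placements the wrapped tensor ends up as a leaf that requires grad, so moving requires_grad by itself does not change anything:

```python
import torch

dim = 4  # toy size for illustration

# requires_grad on the inner tensor, as in the original post
u_inner = torch.nn.Parameter(torch.randn(dim, requires_grad=True))
# requires_grad passed to Parameter instead (nn.Parameter defaults to True anyway)
u_outer = torch.nn.Parameter(torch.randn(dim), requires_grad=True)

for name, p in [("inner", u_inner), ("outer", u_outer)]:
    print(name, "is_leaf:", p.is_leaf, "requires_grad:", p.requires_grad)
```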

Hi @marcmuc,
Tried this, but it still doesn’t work.

Yes, is_leaf is True. So how do I fix this?

Hi,

The problem is that you use .data. You should not use .data anymore.
There used to be two use cases for .data, and their replacements are:
- If you want to do ops that are not recorded by the autograd engine, wrap them in a with torch.no_grad(): block.
- If you want to break the computational graph to prevent gradients from flowing back, use .detach().
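A small illustration of the two replacements, on a toy tensor w (a hypothetical name, not from the post):

```python
import torch

w = torch.nn.Parameter(torch.ones(3))

# 1) In-place change that autograd should not record: use torch.no_grad()
with torch.no_grad():
    w.mul_(2.0)          # allowed on a leaf that requires grad, since recording is off
print(w.grad_fn)         # None: w is still a leaf

# 2) Cut the graph so no gradient flows back through a value: use .detach()
y = (w * 3).detach()     # same values, but no autograd history
print(y.requires_grad)   # False
```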

What you are trying to do is not a valid use of .data: in particular, it hides from the autograd engine the fact that you changed the weights.
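A reduced sketch of the original setup (dim=4 and the random input are arbitrary; torch.outer is the current name for torch.ger). Because the weight is written through .data, the graph never connects the loss to u, so u.grad stays None even though backward() runs through the trainable output layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 4
lin = nn.Linear(dim, dim, bias=False)
lin.weight.requires_grad = False
out = nn.Linear(dim, 2)            # trainable layer, so backward() has something to do
u = nn.Parameter(torch.randn(dim))

# Writing through .data copies the values but records nothing
lin.weight.data = torch.outer(u, u)

loss = out(lin(torch.randn(2, dim))).sum()
loss.backward()
print(u.grad)                      # None: nothing in the graph depends on u
print(out.weight.grad is not None) # True: the output layer still gets gradients
```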
You need to compute U and change the weights inside your forward pass.
Also, Linear layers are not built to be used this way and will not let you make their .weight anything other than an nn.Parameter. I guess the simplest option here is to do self.hidden1.weight.copy_(torch.ger(self.u, self.u)) in your forward pass and keep requires_grad=False in the __init__ function.
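A sketch of the suggested fix for a single forward/backward step, under toy assumptions (dim=4, a random input, and a hypothetical module-level forward function instead of a full nn.Module). The copy_ is recorded by autograd, so u now receives a gradient:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 4
lin = nn.Linear(dim, dim, bias=False)
lin.weight.requires_grad = False   # keep the Linear weight frozen, as suggested
out = nn.Linear(dim, 2)
u = nn.Parameter(torch.randn(dim))

def forward(x):
    # recompute U from u and copy it into the weight; autograd records the copy
    lin.weight.copy_(torch.outer(u, u))
    return out(lin(x))

loss = forward(torch.randn(3, dim)).sum()
loss.backward()
print(u.grad is not None)          # True
```

This only checks one step. Note that copy_ also leaves autograd history (a grad_fn) on the weight itself, which matters once the forward pass runs repeatedly.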

@albanD
Thanks for the solution, but I am new to PyTorch and I get the following error with it:

Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

This error means that you computed something once and used it in multiple forward passes, and thus in multiple backward passes.
Make sure that your forward pass only uses the original tensor that you created with requires_grad=True (a leaf Tensor), not intermediate results.
In your case, as I said, U should be recomputed at every forward pass!
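A minimal reproduction of the error and of the fix described above, using plain tensors (u, x and the matrix-vector product are stand-ins for the model):

```python
import torch

torch.manual_seed(0)
u = torch.nn.Parameter(torch.randn(4))
x = torch.randn(4)

# Wrong: U is built once and reused, so both losses share U's graph
U = torch.outer(u, u)
(U @ x).sum().backward()           # first backward frees U's saved tensors
try:
    (U @ x).sum().backward()       # second backward through the same U
    reused_failed = False
except RuntimeError:
    reused_failed = True

# Right: rebuild U from the leaf u on every pass
u.grad = None
for _ in range(2):
    U = torch.outer(u, u)          # fresh graph each iteration
    (U @ x).sum().backward()       # gradients accumulate into u.grad
print(reused_failed, u.grad is not None)
```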

@albanD
I didn’t get it.
Should I compute U before the forward function:

U = torch.ger(self.u, self.u)
def forward(self, x):
    self.hidden1.weight.copy_(U)

or inside it:

def forward(self, x):
    U = torch.ger(self.u, self.u)
    self.hidden1.weight.copy_(U)

You should compute it inside.

If it still gives you an error, please post a small code sample that reproduces the issue; that will make it easier to help.

@albanD Here is the code after your suggestion:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        dim = 100

        self.hidden1 = nn.Linear(dim, dim, bias=False)
        self.output = nn.Linear(dim, 10)

        self.hidden1.weight.requires_grad = False
        self.u = torch.nn.Parameter(torch.randn(dim), requires_grad=True)

    def forward(self, x, W):
        # U is rebuilt from u on every forward pass
        U = torch.ger(self.u, self.u)
        self.hidden1.weight.copy_(U * W)

        x = self.hidden1(x)
        x = self.output(x)
        return F.log_softmax(x, dim=1)

model = Net()
model = model.cuda()
output = model(data, W)
loss = F.nll_loss(output, target)
loss.backward()

Now the error is

Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Now you need to make sure that all your data is on the same device.
Be careful, though: y = x.cuda() is an autograd operation, so if x requires gradients and you compute y once and then use it in your forward pass, you will hit the same issue of doing backward twice.
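A CPU-only sketch of the point above. As an assumption so that it runs without a GPU, .to(torch.float64) stands in for .cuda(); both are recorded by autograd when the source tensor requires grad. Moving the whole module instead keeps its parameters as leaves:

```python
import torch
import torch.nn as nn

x = torch.randn(3, requires_grad=True)

# Casting/moving a tensor that requires grad is recorded by autograd.
# (.to(torch.float64) stands in for .cuda() here, as an assumption for CPU-only runs.)
y = x.to(torch.float64)
print(y.is_leaf, y.grad_fn is not None)      # False True

# nn.Module.to() / .cuda() converts the parameters themselves,
# so they stay leaf tensors that an optimizer can update.
lin = nn.Linear(3, 3)
lin.to(torch.float64)
print(lin.weight.is_leaf, lin.weight.dtype)  # True torch.float64
```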

@albanD
I have made the changes to the code (see above) and passed the variable W to the forward function directly instead of __init__(). Now both x and W are on the CUDA device, but I am still getting the same error.

Is your model on the GPU as well? You need to call model.cuda() too.

@albanD
Yes, it is on the GPU.

If you have

model.cuda()
data = data.cuda()
target = target.cuda()
W = W.cuda()

in your code, then I’m not sure where the error can come from…

@albanD

Yes, I have everything on the CUDA device as you mentioned, but I still get the same error. Thanks anyway for the prompt responses. Maybe I need to do something else to make this work (not sure right now).
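One plausible explanation, not confirmed in the thread: copy_ is itself recorded by autograd, so after the first call the Linear weight is no longer a leaf and keeps a grad_fn between iterations. The next copy_ links the new graph to the previous iteration's graph, whose buffers were already freed. A toy CPU reproduction (dim=4, with random W and x as stand-ins for the real data):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 4
lin = nn.Linear(dim, dim, bias=False)
lin.weight.requires_grad = False
u = nn.Parameter(torch.randn(dim))
W = torch.randn(dim, dim)      # stand-in for the fixed matrix W, no grad
x = torch.randn(2, dim)        # stand-in for the input batch

errors = []
for step in range(2):
    # in-place copy: the weight keeps this grad_fn after the step ends
    lin.weight.copy_(torch.outer(u, u) * W)
    loss = lin(x).sum()
    try:
        loss.backward()
    except RuntimeError:
        errors.append(step)

print(lin.weight.is_leaf)  # False: the Parameter now carries autograd history
print(errors)              # [1]: the first step works, the second fails
```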

I found a way that works in general with autograd when you replace a parameter with some other tensor and want to backpropagate through that tensor. Instead of

self.hidden1.weight.copy_(U*W)

use

self.hidden1._parameters['weight'] = U*W

module._parameters[<name>] = <tensor> sidesteps nn.Parameter and, as far as I can tell, just treats the tensor normally with autograd.
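A sketch of the _parameters trick over several optimizer steps, next to a swapped-in alternative that is not from the thread: calling F.linear with the computed weight directly, which avoids touching a Linear module at all. The class names NetParamSwap and NetFunctional and the helper train are hypothetical; dim, W and x are toy stand-ins; the optimizer is built on u alone because the swapped-in weight is not a Parameter:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim = 4                      # toy size (the post uses 100)
W = torch.randn(dim, dim)    # stand-in for the fixed matrix W, no grad
x = torch.randn(2, dim)      # stand-in for the input batch


class NetParamSwap(nn.Module):
    # the reply's trick: store a freshly computed tensor under 'weight' each pass
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(dim, dim, bias=False)
        self.u = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        self.hidden1._parameters['weight'] = torch.outer(self.u, self.u) * W
        return self.hidden1(x)


class NetFunctional(nn.Module):
    # alternative, not from the thread: no Linear module, call F.linear directly
    def __init__(self):
        super().__init__()
        self.u = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        return F.linear(x, torch.outer(self.u, self.u) * W)


def train(model, steps=3):
    # optimize u only; returns True if u moved
    opt = torch.optim.SGD([model.u], lr=0.01)
    start = model.u.detach().clone()
    for _ in range(steps):
        opt.zero_grad()
        loss = model(x).pow(2).sum()
        loss.backward()      # a fresh graph every step, so no double-backward error
        opt.step()
    return not torch.equal(start, model.u.detach())


changed_swap = train(NetParamSwap())
changed_functional = train(NetFunctional())
print(changed_swap, changed_functional)
```

With the _parameters trick, model.parameters() yields whatever tensor was last stored under 'weight' (a non-leaf after the first forward pass), so build the optimizer from the parameters you actually want to train.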