Calling backward() on the loss doesn't update the parameters

Hi, I’m playing around with this toy example:

```python
import torch
from torch import nn, optim
import matplotlib.pyplot as plt

net1 = torch.nn.Linear(1,2)
net2 = torch.nn.Linear(1,1)

x = torch.Tensor([[1.]])
z = torch.Tensor([[2.]])

x.requires_grad=True
z.requires_grad=True

optimizer = optim.Adam(net1.parameters(), lr=1e-3)
Loss = []
for _ in range(50):
    optimizer.zero_grad()

    y = net1(x) # get the output from the first neural network

    net2.weight.data = y[0][0].reshape(-1,1) # use the output of the first net as the parameters of the second net
    net2.bias.data = y[0][1].reshape(-1,1)
    pred = net2(z)

    ## the above can also be written as pred = y[0][0].reshape(-1,1) @ z + y[0][1].reshape(-1,1), which I think works
    loss = (pred - 2).pow(2).sum()

    loss.backward()
    print(loss.grad) # prints None for either approach above
    optimizer.step()

    Loss.append(loss.item())
plt.plot(Loss)
```

The computation above is simply pred = [z, 1] @ (W1*x + b1), where @ is the dot product of two vectors, and I am trying to train the network to output 2.
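For reference, a minimal sketch of the manual formulation from the comment in the code: building pred directly from y with ordinary tensor operations keeps everything in the autograd graph, so gradients reach net1. The seed, learning rate and number of steps are arbitrary choices for illustration.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
net1 = nn.Linear(1, 2)
x = torch.tensor([[1.]])
z = torch.tensor([[2.]])

optimizer = optim.Adam(net1.parameters(), lr=1e-2)
losses = []
for _ in range(200):
    optimizer.zero_grad()
    y = net1(x)                      # shape (1, 2): [w, b] for the second "layer"
    pred = y[0, 0] * z + y[0, 1]     # same as [z, 1] @ y[0], built from differentiable ops
    loss = (pred - 2).pow(2).sum()
    loss.backward()                  # gradients flow back into net1
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```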

But clearly loss.backward() doesn’t produce the proper gradients with the first approach: net1 never changes. How do I fix it? Ideally I want this to work with nn.Conv2d, so I can avoid computing the convolution layer manually.

Hi,

loss.grad is expected to remain None, as only the gradients of leaf tensors are saved in the .grad field, and the loss is not a leaf (it is the output of an operation).
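To illustrate with a small sketch (the values are arbitrary): a tensor created directly by the user is a leaf, while the result of an operation such as the loss is not. If you really want the gradient of a non-leaf tensor, you can opt in with retain_grad() before calling backward().

```python
import torch

w = torch.tensor([3.], requires_grad=True)   # leaf: created by the user
loss = (w * 2 - 1).pow(2).sum()              # non-leaf: produced by an operation

loss.retain_grad()        # opt in to keeping .grad on a non-leaf tensor
loss.backward()

print(w.is_leaf, loss.is_leaf)   # True False
print(w.grad)                    # d(loss)/dw = 2 * (2w - 1) * 2 = 20 at w = 3
print(loss.grad)                 # gradient of loss w.r.t. itself, i.e. 1
```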

These two lines

```python
net2.weight.data = y[0][0].reshape(-1,1) # using the output from first net as parameter to second net
net2.bias.data = y[0][1].reshape(-1,1)
```

use .data and thus break the computational graph, meaning that no gradients will flow back to net1.
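A small sketch that makes the break visible (shapes follow the original example; the bias is reshaped to (1,), the usual nn.Linear bias shape): after backward(), net2's parameters receive gradients, but net1's do not, because assigning through .data only copies values and autograd records nothing.

```python
import torch
from torch import nn

net1 = nn.Linear(1, 2)
net2 = nn.Linear(1, 1)
x = torch.tensor([[1.]])
z = torch.tensor([[2.]])

y = net1(x)
net2.weight.data = y[0][0].reshape(-1, 1)   # copies values only; autograd does not record this
net2.bias.data = y[0][1].reshape(-1)
loss = (net2(z) - 2).pow(2).sum()
loss.backward()

print(net2.weight.grad)   # populated: net2's parameters are still leaves in the graph
print(net1.weight.grad)   # None: nothing links loss back to net1
```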

If you want to set these weights, the right way to do this is to delete the existing parameters with del net2.weight and then set the field to the tensor that you want: net2.weight = y[0][0].reshape(-1,1).
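A minimal sketch of that fix applied to the original loop. The seed, learning rate and step count are arbitrary, and the bias is reshaped to (1,) to match nn.Linear's usual bias shape. Note that after the del, net2 has no parameters of its own, so only net1 is given to the optimizer.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
net1 = nn.Linear(1, 2)
net2 = nn.Linear(1, 1)

# Remove net2's own parameters once, right after creating the model, so that
# plain tensors (with autograd history) can be assigned in their place.
del net2.weight
del net2.bias

x = torch.tensor([[1.]])
z = torch.tensor([[2.]])

optimizer = optim.Adam(net1.parameters(), lr=1e-2)
losses = []
for _ in range(200):
    optimizer.zero_grad()
    y = net1(x)
    net2.weight = y[0, 0].reshape(1, 1)   # Linear weight shape: (out_features, in_features)
    net2.bias = y[0, 1].reshape(1)        # Linear bias shape: (out_features,)
    pred = net2(z)
    loss = (pred - 2).pow(2).sum()
    loss.backward()                       # gradients now reach net1
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```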

Hi, using net2.weight = y[0][0].reshape(-1,1) raises an error: cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

So I switched to net2.weight = torch.nn.Parameter(y[0][0].reshape(-1,1)), but the value still doesn’t update. :frowning:

The second one does not update because the call to nn.Parameter() breaks the graph as well.
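A short sketch of why (shapes as in the original example): nn.Parameter wraps the values in a brand-new leaf tensor with no history, so backpropagation stops at that new Parameter instead of continuing into net1.

```python
import torch
from torch import nn

net1 = nn.Linear(1, 2)
net2 = nn.Linear(1, 1)
x = torch.tensor([[1.]])
z = torch.tensor([[2.]])

y = net1(x)
# nn.Parameter creates a new leaf tensor; the link to y (and to net1) is lost.
net2.weight = nn.Parameter(y[0][0].reshape(1, 1))
print(net2.weight.is_leaf, net2.weight.grad_fn)   # True None

loss = (net2(z) - 2).pow(2).sum()
loss.backward()
print(net2.weight.grad)   # populated: the new Parameter receives the gradient
print(net1.weight.grad)   # None: backprop stops at the new Parameter
```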

Re-read my answer: you need to del net2.weight beforehand (just after creating the model), and then you will be able to assign it a Tensor that requires gradients.
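Since the original goal was nn.Conv2d: the same del-then-assign approach should carry over, because Conv2d's forward also reads self.weight and self.bias. A minimal sketch, assuming a hypothetical generator network `hyper` (a plain nn.Linear) that produces every weight and bias of a small 3-to-4-channel 3×3 convolution; all sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3)
w_shape = conv.weight.shape          # (4, 3, 3, 3)
b_shape = conv.bias.shape            # (4,)
n_w, n_b = w_shape.numel(), b_shape.numel()

del conv.weight                      # drop the conv's own parameters once
del conv.bias

hyper = nn.Linear(1, n_w + n_b)      # hypothetical network that generates the conv weights
x = torch.tensor([[1.]])
img = torch.randn(2, 3, 8, 8)        # a batch of 2 random 3-channel 8x8 images

params = hyper(x)[0]
conv.weight = params[:n_w].reshape(w_shape)
conv.bias = params[n_w:].reshape(b_shape)

out = conv(img)                      # shape (2, 4, 6, 6)
out.pow(2).mean().backward()
print(hyper.weight.grad is not None) # True: gradients reach the generating network
```

In a training loop, the two assignments would be repeated on every iteration, after recomputing `hyper(x)`, as in the Linear version.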

Thanks, I wasn’t reading carefully and missed the delete part. Thanks for the help! :slight_smile:
