How to set gradients manually and update weights using them?

I have a situation where I compute the gradients manually and want to update the weights using them. Here is what I did:

optimizer.zero_grad()
param.grad = Variable(grad_tensor)
optimizer.step()

But the weights are not updated. Is there something I'm missing?
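
For reference, a minimal runnable version of what I am attempting (the parameter and gradient values are just placeholders):

import torch

param = torch.nn.Parameter(torch.zeros(3))    # a stand-in for a real model weight
optimizer = torch.optim.SGD([param], lr=0.1)

optimizer.zero_grad()
param.grad = torch.ones(3)   # gradient computed outside autograd, assigned by hand
optimizer.step()

print(param)  # expect every entry to move from 0.0 to -0.1 (i.e. 0 - lr * 1)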


1. Are you 100% sure that there is no difference in param after the step? Maybe you are using an adaptive optimizer and the effective learning rate is almost zero? (A quick check is sketched below.)
2. Are you 100% sure that the optimizer is responsible for updating this param?
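
A quick way to check point 1 is to snapshot the parameter before the step and compare afterwards. A minimal sketch, where param and optimizer stand in for your own objects:

before = param.detach().clone()                 # snapshot the current values
optimizer.step()
print(not torch.equal(before, param.detach()))  # True if the step changed anything
print((param.detach() - before).abs().max())    # size of the largest single update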

I wrote this code to test it (I am using v0.4.0):

import torch

target = torch.randn(1, 10)
input = torch.randn(1, 20)
network = torch.nn.Linear(20, 10)
optimizer = torch.optim.SGD(network.parameters(), lr=0.01)

optimizer.zero_grad()
loss = ((network(input) - target)**2).sum()
loss.backward()
print(network.weight[0, 0])
optimizer.step()
print(network.weight[0, 0])  # works (it is changed)

network.weight.grad[0, 0] = 100  # overwrite a single gradient entry in place
optimizer.step()
print(network.weight[0, 0])  # works (it is changed)

network.weight.grad = torch.randn(10, 20)  # replace the whole gradient tensor
optimizer.step()
print(network.weight[0, 0])  # works (it is changed)

Thanks for your answer. I was indeed doing the same, though I was using setattr() to set the values dynamically, and I found a bug in my logic. It is working fine now!
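
In case it helps someone else: the dynamic assignment does not actually need setattr(); you can assign .grad directly while iterating named_parameters(). A sketch, where the gradient dict stands in for whatever you compute manually:

# gradients computed elsewhere, keyed by parameter name (placeholder values here)
manual_grads = {name: torch.randn_like(p) for name, p in network.named_parameters()}

optimizer.zero_grad()
for name, p in network.named_parameters():
    p.grad = manual_grads[name]  # direct assignment, no setattr() needed
optimizer.step()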

Do you have any idea how to make it work for the following scenario?

import torch
import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.l1 = nn.Linear(20, 15)
        self.l2 = nn.Linear(15, 10)

    def forward(self, input):
        out = self.l1(input).detach()  # detach() cuts l1 out of the autograd graph
        return self.l2(out)

target = torch.randn(1, 10)
input = torch.randn(1, 20)
network = Network()  # the network must exist before the optimizer can take its parameters
optimizer = torch.optim.SGD(network.parameters(), lr=0.01)
optimizer.zero_grad()
loss = ((network(input) - target)**2).sum()
loss.backward()
print(network.l1.weight[0, 0])
optimizer.step()
print(network.l1.weight[0, 0])  # doesn't work (it is not changed)
network.l1.weight.grad = torch.randn(15, 20)
optimizer.step()
print(network.l1.weight[0, 0])  # doesn't work (it is not changed)
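
For reference, printing the gradients right after loss.backward() shows what happens to l1 here (a sketch against the code above):

print(network.l1.weight.grad)          # None: detach() stops backward() before it reaches l1
print(network.l2.weight.grad is None)  # False: l2 did receive a gradient
# SGD skips any parameter whose .grad is None, so the first step() leaves l1 untouched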
