Hi,
When calling the backward() method, is it possible to propagate gradients to multiple leaves at once? It seems like autograd only registers the updates an optimizer needs on a single leaf node along the computation graph.
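To illustrate what I mean with a minimal toy graph (not my actual network): only leaf tensors accumulate .grad, while intermediate results get nothing unless retain_grad() is called on them first.
import torch

a = torch.tensor(2.0, requires_grad=True)  # leaf
b = a * 3                                  # intermediate, non-leaf
c = b ** 2
c.backward()
print(a.grad)  # tensor(36.) -- the leaf accumulates a gradient
print(b.grad)  # None -- the intermediate does not, unless b.retain_grad() was called before backward()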
The following code produces what I want (the correct gradients and updates after optimizer.step(), optimizer2.step(), and optimizer3.step()), but I would like the same result from a single forward and backward pass covering params, params2, and params3, rather than one dedicated pass for each.
import torch
import torch.optim as optim

# custom, homemade network instantiation
network = MyNetwork()
# some additional parameters (not important)
fc_weights = torch.tensor([[1, 0], [0, 1]], requires_grad=False, dtype=torch.float)
input_mags = torch.tensor([[2], [3]], requires_grad=False, dtype=torch.float)
# learnable parameters
params = torch.tensor([[0.25], [0.125]], requires_grad=True, dtype=torch.float)
params2 = torch.rand(2, 1, requires_grad=True)
params3 = torch.rand(2, 1, requires_grad=True)
# just some sample large learning rate to probe changes
learning_rate = 1000
# pass dedicated to `params`
#####
optimizer = optim.SGD([params], lr=learning_rate)
optimizer.zero_grad()
x = torch.hstack((input_mags, params))
out = network(x, fc_weights)
params2.data = out.data
out = network(out, fc_weights)
params3.data = out.data
loss = MyLoss(out)
loss.backward()
#####
# pass dedicated to `params2`
#####
optimizer2 = optim.SGD([params2], lr=learning_rate)
optimizer2.zero_grad()
out = network(params2, fc_weights)
loss = MyLoss(out)
loss.backward()
#####
# pass dedicated to `params3`
#####
optimizer3 = optim.SGD([params3], lr=learning_rate)
optimizer3.zero_grad()
loss = MyLoss(params3)
loss.backward()
#####
# prints correct gradients
print(params.grad)
print(params2.grad)
print(params3.grad)
optimizer.step()
optimizer2.step()
optimizer3.step()
# prints correct updates
print(params)
print(params2)
print(params3)
Additional info:
The following are the unsuccessful attempts I tried. It makes sense why the first attempt doesn't work: the .data assignments break the computation graph, so the leaf node where the gradient flow makes its last stop is params3. It also makes sense why the second attempt doesn't work: everything gets recorded against the only true leaf, params, while params2 and params3 just act as intermediate placeholders (the names are rebound to non-leaf outputs, so the optimizers still hold the original, now-unused leaf tensors).
Attempt 1:
# custom, homemade network instantiation
network = MyNetwork()
# some additional parameters (not important)
fc_weights = torch.tensor([[1, 0], [0, 1]], requires_grad=False, dtype=torch.float)
input_mags = torch.tensor([[2], [3]], requires_grad=False, dtype=torch.float)
# learnable parameters
params = torch.tensor([[0.25], [0.125]], requires_grad=True, dtype=torch.float)
params2 = torch.rand(2, 1, requires_grad=True)
params3 = torch.rand(2, 1, requires_grad=True)
# just some sample large learning rate to probe changes
learning_rate = 1000
optimizer = optim.SGD([params], lr=learning_rate)
optimizer2 = optim.SGD([params2], lr=learning_rate)
optimizer3 = optim.SGD([params3], lr=learning_rate)
optimizer.zero_grad()
optimizer2.zero_grad()
optimizer3.zero_grad()
x = torch.hstack((input_mags, params))
out1 = network(x, fc_weights)
params2.data = out1.data  # value copy only; the graph behind out1 is not attached to params2
params2.retain_grad()     # no-op: params2 is already a leaf
out2 = network(params2, fc_weights)
params3.data = out2.data
params3.retain_grad()
loss = MyLoss(params3)
loss.backward()
# prints gradients
print(params.grad) # is None, which is not what I want
print(params2.grad) # is None, which is not what I want
print(params3.grad) # prints correctly
optimizer.step()
optimizer2.step()
optimizer3.step()
print(params) # no gradient update, which is not what I want
print(params2) # no gradient update, which is not what I want
print(params3) # gradient updates value correctly
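As a stripped-down illustration of the break in Attempt 1 (toy tensors, not my real setup): assigning through .data only transfers values, so backward() never reaches the graph that produced the source tensor.
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
p = torch.zeros(2, requires_grad=True)
y = x * 2
p.data = y.data  # values transferred, but p stays an unrelated leaf
loss = p.sum()
loss.backward()
print(p.grad)    # tensor([1., 1.])
print(x.grad)    # None -- the graph behind y was never touched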
Attempt 2:
# custom, homemade network instantiation
network = MyNetwork()
# some additional parameters (not important)
fc_weights = torch.tensor([[1, 0], [0, 1]], requires_grad=False, dtype=torch.float)
input_mags = torch.tensor([[2], [3]], requires_grad=False, dtype=torch.float)
# learnable parameters
params = torch.tensor([[0.25], [0.125]], requires_grad=True, dtype=torch.float)
params2 = torch.rand(2, 1, requires_grad=True)
params3 = torch.rand(2, 1, requires_grad=True)
# just some sample large learning rate to probe changes
learning_rate = 1000
optimizer = optim.SGD([params], lr=learning_rate)
optimizer2 = optim.SGD([params2], lr=learning_rate)
optimizer3 = optim.SGD([params3], lr=learning_rate)
optimizer.zero_grad()
optimizer2.zero_grad()
optimizer3.zero_grad()
x = torch.hstack((input_mags, params))
params2 = network(x, fc_weights)  # rebinds the name: this params2 is a non-leaf output, not the leaf held by optimizer2
params2.retain_grad()
params3 = network(params2, fc_weights)  # likewise rebinds params3
params3.retain_grad()
loss = MyLoss(params3)
loss.backward()
# prints gradients correctly
print(params.grad)
print(params2.grad)
print(params3.grad)
optimizer.step()
optimizer2.step()
optimizer3.step()
print(params) # gradient updates values correctly
print(params2) # no gradient update
print(params3) # no gradient update
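One idea I'm considering for collapsing this into a single pass, though I'm not sure it's correct or idiomatic: a zero-valued re-attachment trick, where adding params2 - params2.detach() to an intermediate output keeps the forward value unchanged (the difference is numerically zero) while routing a copy of the upstream gradient into the leaf params2. A sketch, assuming the same MyNetwork / MyLoss setup as above:
import torch
import torch.optim as optim

network = MyNetwork()
fc_weights = torch.tensor([[1, 0], [0, 1]], requires_grad=False, dtype=torch.float)
input_mags = torch.tensor([[2], [3]], requires_grad=False, dtype=torch.float)
params = torch.tensor([[0.25], [0.125]], requires_grad=True, dtype=torch.float)
params2 = torch.rand(2, 1, requires_grad=True)
params3 = torch.rand(2, 1, requires_grad=True)
learning_rate = 1000
optimizer = optim.SGD([params], lr=learning_rate)
optimizer2 = optim.SGD([params2], lr=learning_rate)
optimizer3 = optim.SGD([params3], lr=learning_rate)
optimizer.zero_grad()
optimizer2.zero_grad()
optimizer3.zero_grad()

x = torch.hstack((input_mags, params))
out1 = network(x, fc_weights)
params2.data = out1.data                  # sync the value, as in the three-pass version
h2 = out1 + (params2 - params2.detach())  # value == out1, but params2 now receives dloss/dh2
out2 = network(h2, fc_weights)
params3.data = out2.data
h3 = out2 + (params3 - params3.detach())  # value == out2, but params3 receives dloss/dh3
loss = MyLoss(h3)
loss.backward()  # one backward pass populates params.grad, params2.grad, and params3.grad
Is something along these lines the intended way to get gradient flow into multiple leaves from one backward() call, or is there a more standard approach?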