I am trying to implement a toy meta-learning example. Below is code that demonstrates what I am trying to do:
import torch
import torch.nn as nn

torch.random.manual_seed(1)

net = nn.Linear(1, 1, bias=False)
# print("net parameter is ", list(net.parameters())[0])
optim = torch.optim.SGD(net.parameters(), lr=1e-3)

meta_net = nn.Linear(1, 1, bias=False)
# print("meta net parameter is ", list(meta_net.parameters())[0])
meta_optim = torch.optim.SGD(meta_net.parameters(), lr=1e-3)

x = torch.randn((1, 1))
y = net(x)
z = meta_net(x)

optim.zero_grad()
loss = y * z
loss.backward(retain_graph=True)
for p in meta_net.parameters():
    print(p.grad)
optim.step()
# print("net parameter is ", list(net.parameters())[0])

x = torch.randn((1, 1))
z = net(x)
meta_optim.zero_grad()
for p in meta_net.parameters():
    print(p.grad)
loss = z
loss.backward(retain_graph=True)
for p in meta_net.parameters():
    print(p.grad)
If we define the weight of net to be w1 and the weight of meta_net to be w2, then the first parameter update should set w1' = w1 - lr * w2 * x^2, i.e. w1' is a function of w2. Now, when I use the updated net weight w1' to obtain z, z should also be a function of w2. However, when I backpropagate through z, the gradient of w2 is 0 (the print statements confirm this). What is going wrong here?
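For comparison, here is a minimal sketch of the behaviour I expected. If I write the SGD step out by hand as a differentiable expression (instead of calling optim.step()) and re-apply the updated weight with torch.nn.functional.linear, the dependence of w1' on w2 survives and backpropagating through z gives a nonzero gradient for w2:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.random.manual_seed(1)
net = nn.Linear(1, 1, bias=False)
meta_net = nn.Linear(1, 1, bias=False)
lr = 1e-3

x = torch.randn((1, 1))
# loss = y * z with y = w1 * x and z = w2 * x, so dloss/dw1 = w2 * x^2.
# Perform the SGD step manually as a differentiable expression, so that
# w1_new keeps its dependence on w2 inside the autograd graph.
w1 = net.weight
w2 = meta_net.weight
w1_new = w1 - lr * w2 * (x ** 2)  # w1' = w1 - lr * w2 * x^2

x2 = torch.randn((1, 1))
z = F.linear(x2, w1_new)  # second forward pass using the updated weight
z.sum().backward()
print(meta_net.weight.grad)  # expected to be nonzero: dz/dw2 = -lr * x^2 * x2
```

This is only a sketch of what I expect mathematically, not how my actual code performs the update; in my code above the update goes through optim.step().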
Note that I initially tried retain_graph=False in the .backward() calls and thought that might be what was causing the issue, but it still does not work with retain_graph=True.