(Toy) Meta-learning example not working

I am trying to implement a toy meta-learning example. Below is code that demonstrates what I am trying to do:

import torch
import torch.nn as nn

torch.random.manual_seed(1)


# "inner" network and its optimizer
net = nn.Linear(1, 1, bias=False)
# print("net parameter is ", list(net.parameters())[0])
optim = torch.optim.SGD(net.parameters(), lr=1e-3)

# "meta" network and its optimizer
meta_net = nn.Linear(1, 1, bias=False)
# print("meta net parameter is ", list(meta_net.parameters())[0])
meta_optim = torch.optim.SGD(meta_net.parameters(), lr=1e-3)

# inner step: the loss couples net and meta_net
x = torch.randn((1, 1))
y = net(x)
z = meta_net(x)

optim.zero_grad()
loss = y * z
loss.backward(retain_graph=True)
for p in meta_net.parameters():
    print(p.grad)  # gradient of the joint loss w.r.t. meta_net's weight
optim.step()  # updates net's weight with a gradient that depends on meta_net's weight
# print("net parameter is ", list(net.parameters())[0])

# meta step: evaluate the updated net on new data and backpropagate to meta_net
x = torch.randn((1, 1))
z = net(x)
meta_optim.zero_grad()
for p in meta_net.parameters():
    print(p.grad)  # just zeroed by meta_optim.zero_grad()
loss = z
loss.backward(retain_graph=True)
for p in meta_net.parameters():
    print(p.grad)  # still zero (this is the problem)

If we call the weight of net w1 and the weight of meta_net w2, then the first loss is y * z = w1 * w2 * x^2, so the first parameter update should change w1 to w1' = w1 - lr * w2 * x^2, i.e. a function of w2. Now, when I use the updated net with weight w1' on the second input to obtain z, z should also be a function of w2. However, when I backpropagate through z, the gradient of w2 is 0 (the print statements confirm this). What is going wrong here?
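To make the expectation concrete, here is a scalar re-derivation of what I think should happen. This is not my actual training code: w1, w2, x1 and x2 are just stand-in numbers for the two weights and the two sampled inputs, and lr is the same learning rate, with the inner SGD step written out by hand.

import torch

lr = 1e-3
w1 = torch.randn(())                       # stand-in for net's initial weight
w2 = torch.randn((), requires_grad=True)   # stand-in for meta_net's weight
x1 = torch.randn(())                       # first input
x2 = torch.randn(())                       # second input

# expected inner SGD step: w1' = w1 - lr * d(w1 * w2 * x1^2)/dw1 = w1 - lr * w2 * x1^2
w1_updated = w1 - lr * w2 * x1 ** 2

# z on the second input, computed with the updated weight
z = w1_updated * x2

z.backward()
print(w2.grad)             # -lr * x1^2 * x2, non-zero in general
print(-lr * x1 ** 2 * x2)  # hand-derived value for comparison

In this hand-rolled version the gradient of z with respect to w2 is non-zero, whereas in the code above it comes out as 0.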

Note that I initially called .backward() with retain_graph=False and thought that might be what was causing the issue, but switching to retain_graph=True does not fix it.