I’m trying out meta-learning in PyTorch, and I’m having trouble propagating gradients through the parameters of an nn.Module. Here’s a simplified example that doesn’t work as I expected:

import torch
from torch import nn
net = nn.Linear(2, 1)
x = torch.tensor([[1., 2.]])
y = torch.tensor([[2.]])
loss = (net(x) - y)**2
grads = torch.autograd.grad(loss, net.parameters(), create_graph=True)
net2 = nn.Linear(2, 1)
net2.weight.data = net2.weight.data - 0.01*grads[0]
net2.bias.data = net2.bias.data - 0.01*grads[1]
loss2 = (net2(x) - y)**2
loss2.backward()
print(net2.weight.grad)  # This is populated as expected
print(net.weight.grad)   # This is 'None', which is not what I'd like
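The `None` comes from the `.data` assignment. The update expression itself is still connected to `net`, but assigning it to `.data` copies only the values into `net2.weight`, not their autograd history, so `loss2` has no path back to `net`. Here is a minimal check that reuses the setup above. The names `new_weight` and `g` are illustrative only.

```python
import torch
from torch import nn

net = nn.Linear(2, 1)
x = torch.tensor([[1., 2.]])
y = torch.tensor([[2.]])
loss = ((net(x) - y) ** 2).sum()
grads = torch.autograd.grad(loss, list(net.parameters()), create_graph=True)

net2 = nn.Linear(2, 1)
# Because grads[0] was built with create_graph=True, this expression
# still carries a grad_fn that leads back to net's parameters.
new_weight = net2.weight.data - 0.01 * grads[0]
print(new_weight.grad_fn is not None)  # True

# Writing into .data stores the values only; net2.weight stays a leaf.
net2.weight.data = new_weight
print(net2.weight.is_leaf, net2.weight.grad_fn)  # True None

# loss2 therefore does not depend on net.weight at all.
loss2 = ((net2(x) - y) ** 2).sum()
(g,) = torch.autograd.grad(loss2, net.weight, allow_unused=True)
print(g)  # None
```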

The gradients of loss with respect to net's parameters are already stored in grads.
If you call loss.backward(), the parameters' .grad attributes will be populated with the same values as grads, so your last print statement will no longer show None.
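A quick way to confirm that the two routes give the same numbers is shown in this small sketch. `torch.autograd.grad` returns the gradients without touching `.grad`, while `loss.backward()` accumulates them into each parameter's `.grad`. `retain_graph=True` is needed only because the same graph is differentiated twice.

```python
import torch
from torch import nn

net = nn.Linear(2, 1)
x = torch.tensor([[1., 2.]])
y = torch.tensor([[2.]])
loss = ((net(x) - y) ** 2).sum()

# Returned directly; net.weight.grad is still None after this call.
grads = torch.autograd.grad(loss, list(net.parameters()), retain_graph=True)

# Accumulated into the parameters' .grad attributes.
loss.backward()
print(torch.allclose(net.weight.grad, grads[0]))  # True
print(torch.allclose(net.bias.grad, grads[1]))    # True
```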

Yes, but I’d like the gradients of loss2 with respect to the original parameters (i.e., dloss2_by_dnetweight, not dloss_by_dnetweight).
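The requested quantity can be written out with the chain rule. This assumes the usual meta-learning update θ' = θ − α∇L(θ), applied to the same parameters θ. Here L is loss, L₂ is loss2, and α is the step size (0.01 in this thread). Because the Hessian of L is symmetric, the transpose on the Jacobian drops out. The second-order Hessian term is the reason grads must be computed with create_graph=True.

```latex
\nabla_{\theta} L_2(\theta')
  = \left( \frac{\partial \theta'}{\partial \theta} \right)^{\!\top} \nabla_{\theta'} L_2(\theta')
  = \left( I - \alpha \, \nabla^2_{\theta} L(\theta) \right) \nabla_{\theta'} L_2(\theta')
```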

To be clear, there’s no problem if torch.nn.functional is used with manually defined weights like this:

import torch
from torch import nn
from torch.nn import functional as F
w = torch.tensor([[1., 2.]], requires_grad=True)
x = torch.tensor([[3., 2.]])
y = torch.tensor([[2.]])
f = F.linear(x, w)
loss = (f - y)**2
grads = torch.autograd.grad(loss, w, create_graph=True)
w2 = w - 0.01*grads[0]
f2 = F.linear(x, w2)
loss2 = (f2 - y)**2
loss2.backward()
print(w.grad)  # This works as expected.
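Because this example is tiny, the autograd result can be checked by hand. This sketch assumes the bias-free linear model above. The inner loss (w·x − y)² has gradient 2(w·x − y)x and Hessian 2xᵀx, so w.grad should equal (I − 0.01·2xᵀx) applied to the gradient of loss2 with respect to w2. Working it through by hand gives roughly [16.428, 10.952]. The names `alpha`, `g2`, `H` and `expected` are introduced here for illustration.

```python
import torch
from torch.nn import functional as F

alpha = 0.01
w = torch.tensor([[1., 2.]], requires_grad=True)
x = torch.tensor([[3., 2.]])
y = torch.tensor([[2.]])

# Same computation as above, reduced to scalars with .sum().
loss = ((F.linear(x, w) - y) ** 2).sum()
(g,) = torch.autograd.grad(loss, w, create_graph=True)
w2 = w - alpha * g
loss2 = ((F.linear(x, w2) - y) ** 2).sum()
loss2.backward()

# Closed form: d loss2/d w = (I - alpha * H) @ d loss2/d w2, with H = 2 x^T x.
with torch.no_grad():
    xv, w2v = x[0], w2[0]                   # flatten to shape (2,)
    g2 = 2 * (w2v @ xv - y[0, 0]) * xv      # d loss2 / d w2
    H = 2 * torch.outer(xv, xv)             # Hessian of the inner loss
    expected = (torch.eye(2) - alpha * H) @ g2

print(w.grad)     # roughly [[16.428, 10.952]]
print(expected)   # roughly [16.428, 10.952]
```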

My problem is with propagating gradients through the nn.Parameters of an nn.Module.
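One possible approach, which is not taken from this thread, is `torch.func.functional_call`. It runs an existing module with a dictionary of tensors substituted for its parameters, so the updated weights can stay ordinary non-leaf tensors instead of being written into `.data`. This sketch assumes PyTorch 2.0 or later. Releases 1.12 and 1.13 offered a similar `torch.nn.utils.stateless.functional_call`. The names `params` and `fast_params` are illustrative.

```python
import torch
from torch import nn
from torch.func import functional_call  # assumes PyTorch >= 2.0

net = nn.Linear(2, 1)
x = torch.tensor([[1., 2.]])
y = torch.tensor([[2.]])
alpha = 0.01

params = dict(net.named_parameters())
loss = ((net(x) - y) ** 2).sum()
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)

# Updated parameters as plain tensors: their graph still reaches net's parameters.
fast_params = {name: p - alpha * gr for (name, p), gr in zip(params.items(), grads)}

# Run the same module with fast_params swapped in for its own parameters.
loss2 = ((functional_call(net, fast_params, (x,)) - y) ** 2).sum()
loss2.backward()

print(net.weight.grad)  # d loss2 / d net.weight, no longer None
print(net.bias.grad)
```

The module itself is never modified, so `net` can be reused for the next outer step with an optimizer in the usual way.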