How to back-propagate gradients through model parameters?

I tried to implement a mini-example of MAML, which requires back-propagating through gradients.

Mathematically, I consider a simple linear model ‘output = param * input’ with MSELoss. I would like to do the following update:

param_new = param - step1 * gradient wrt param of MSELoss(param, train_data)
param = param - step2 * gradient wrt param of MSELoss(param_new, test_data)

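In particular, the term ‘gradient wrt param of MSELoss(param_new, test_data)’ is crucial: param_new is itself a function of param, so this gradient has to flow back through the first update.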
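Writing θ for param, θ' for param_new, α for step1, and L_train, L_test for the two MSE losses (symbols introduced here only for this derivation), the chain rule shows why the crucial term needs second derivatives. With a single-element batch, MSELoss's mean reduction is just the squared error, so the scalar case is:

```latex
\[
\theta' = \theta - \alpha\,\nabla_\theta L_{\mathrm{train}}(\theta),
\qquad
\nabla_\theta L_{\mathrm{test}}(\theta')
  = \bigl(I - \alpha\,\nabla_\theta^2 L_{\mathrm{train}}(\theta)\bigr)\,
    \nabla_{\theta'} L_{\mathrm{test}}(\theta')
\]
\[
L_{\mathrm{train}}(w) = (w x - t)^2:\quad
\frac{d\,L_{\mathrm{test}}(w')}{dw}
  = \bigl(1 - 2\alpha x^2\bigr)\cdot 2\,(w' x_n - t_n)\,x_n
\]
```

With the values used in the snippets below (w = 3, x = t = 2, x_n = t_n = 3, α = 0.1), this gives w' = 1.4, a gradient wrt w' of 7.2 and a factor 1 − 0.8 = 0.2, so the desired gradient wrt w is 0.2 · 7.2 = 1.44. Dropping the Hessian factor leaves 7.2, which is the gradient wrt param_new.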

The following code computes the desired gradient wrt param of MSELoss(param_new, test_data), evaluating the updated model explicitly with torch.nn.functional.linear:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.f = nn.Linear(1,1,bias=False)
        self.init_weight()
    def forward(self, x):
        return self.f(x)
    def init_weight(self):
        nn.init.constant_(self.f.weight, 3.)
m = MyModel()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

x, t = torch.tensor([2.],requires_grad=True), torch.tensor([2.],requires_grad=True)
x_n, t_n = torch.tensor([3.],requires_grad=True), torch.tensor([3.],requires_grad=True)

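# inner step: create_graph=True keeps the graph of p.grad so it can be differentiated again
y = m(x)
l = loss_fn(y, t)
optimizer.zero_grad()
l.backward(create_graph=True)

# param_new = param - step1 * gradient (the model has a single parameter, the weight)
for p in m.parameters():
    p_n = p - 0.1 * p.grad

# outer step: evaluate the test loss with param_new and backpropagate all the way to param
y_n = torch.nn.functional.linear(x_n, p_n)
l_n = loss_fn(y_n, t_n)
optimizer.zero_grad()
l_n.backward()
print(m.f.weight.grad)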
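A variant of the snippet above, as a sketch: it uses torch.autograd.grad(..., create_graph=True) instead of backward(create_graph=True), so the graph-carrying gradient is returned rather than stored in .grad (PyTorch warns that backward(create_graph=True) creates a reference cycle between a parameter and its gradient). A bare nn.Linear stands in for MyModel here; the data and step size are the same.

```python
import torch
import torch.nn as nn

m = nn.Linear(1, 1, bias=False)                      # stands in for MyModel.f
nn.init.constant_(m.weight, 3.)
loss_fn = nn.MSELoss()

x, t = torch.tensor([2.]), torch.tensor([2.])        # train data
x_n, t_n = torch.tensor([3.]), torch.tensor([3.])    # test data

# Inner step: keep the graph of the gradient so it can be differentiated again.
l = loss_fn(m(x), t)
(g,) = torch.autograd.grad(l, m.weight, create_graph=True)
w_n = m.weight - 0.1 * g                             # param_new, still a function of m.weight

# Outer step: test loss at param_new, backpropagated to the original weight.
l_n = loss_fn(torch.nn.functional.linear(x_n, w_n), t_n)
l_n.backward()
print(m.weight.grad)                                 # approximately 1.44
```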

I would like to use the model ‘m’ itself to evaluate ‘y_n’ and ‘l_n’, as in the following code, which writes param_new back into the model with load_state_dict.

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.f = nn.Linear(1,1,bias=False)
        self.init_weight()
    def forward(self, x):
        return self.f(x)
    def init_weight(self):
        nn.init.constant_(self.f.weight, 3.)
m = MyModel()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

x, t = torch.tensor([2.],requires_grad=True), torch.tensor([2.],requires_grad=True)
x_n, t_n = torch.tensor([3.],requires_grad=True), torch.tensor([3.],requires_grad=True)

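# inner step: gradient of the train loss, keeping its graph
y = m(x)
l = loss_fn(y, t)
optimizer.zero_grad()
l.backward(create_graph=True)

# write param_new = param - 0.1 * gradient back into the model
state_dict = m.state_dict()
for n, p in m.named_parameters():
    state_dict[n] = p - 0.1 * p.grad

m.load_state_dict(state_dict)
# outer step: evaluate the test loss with the updated model
y_n = m(x_n)
l_n = loss_fn(y_n, t_n)
optimizer.zero_grad()
l_n.backward()
print(m.f.weight.grad)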

However, the second snippet does not compute the desired gradient. Judging from the printed value, it computes ‘gradient wrt param_new of MSELoss(param_new, test_data)’ instead of ‘gradient wrt param of MSELoss(param_new, test_data)’.
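To check this diagnosis by hand, the sketch below evaluates both candidate gradients in closed form for output = w * input with a single-element MSE loss, using the values from the snippets, and then shows what load_state_dict does to a parameter. The variable names are only for illustration.

```python
import torch
import torch.nn as nn

# Closed-form quantities for output = w * input with a single-element MSE loss.
w, x, t, x_n, t_n, step1 = 3.0, 2.0, 2.0, 3.0, 3.0, 0.1
grad_train = 2 * (w * x - t) * x                     # d L_train / d param = 16
w_new = w - step1 * grad_train                       # param_new = 1.4
grad_wrt_new = 2 * (w_new * x_n - t_n) * x_n         # d L_test / d param_new, about 7.2
grad_wrt_w = (1 - step1 * 2 * x * x) * grad_wrt_new  # d L_test / d param, about 1.44
print(grad_wrt_new, grad_wrt_w)

# load_state_dict copies values into the existing Parameter under torch.no_grad(),
# so the parameter keeps no history of how the new values were computed.
lin = nn.Linear(1, 1, bias=False)
new_w = lin.weight * 0.5                             # a tensor with a grad_fn
lin.load_state_dict({"weight": new_w})
print(lin.weight.is_leaf, lin.weight.grad_fn)        # True None
```

The first snippet's printout should match grad_wrt_w and the second snippet's should match grad_wrt_new. After load_state_dict, m.f.weight is a plain leaf holding param_new's value, so backward stops there instead of continuing through the inner update.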

Eventually I would like to apply a similar meta update to a more complicated model (beyond a simple linear one). Is there an elegant way to compute the desired gradient that uses the model itself instead of explicitly calling torch.nn.functional?
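One possible approach, sketched under the assumption of PyTorch 2.x: torch.func.functional_call runs a module's forward with a dictionary of tensors substituted for its parameters, so param_new never has to be copied into the model and stays connected to param. meta_outer_loss is a hypothetical helper name, not a library function, and torch.autograd.grad replaces backward(create_graph=True) for the inner gradient.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # assumes PyTorch >= 2.0


def meta_outer_loss(model, loss_fn, train, test, inner_lr):
    """Hypothetical helper: MAML-style outer loss for an arbitrary nn.Module.

    param_new is kept in a dict of "fast" weights instead of being written back
    into the module, so the returned loss stays differentiable w.r.t. param.
    """
    (x, t), (x_n, t_n) = train, test
    params = dict(model.named_parameters())
    # Inner step: gradient of the train loss, with its own graph kept.
    inner_loss = loss_fn(functional_call(model, params, (x,)), t)
    grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)
    # param_new = param - step1 * gradient, one entry per named parameter.
    fast = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}
    # Outer loss: the same module, run with param_new substituted for param.
    return loss_fn(functional_call(model, fast, (x_n,)), t_n)


class MyModel(nn.Module):  # same model as in the question
    def __init__(self):
        super().__init__()
        self.f = nn.Linear(1, 1, bias=False)
        nn.init.constant_(self.f.weight, 3.)

    def forward(self, x):
        return self.f(x)


loss_fn = nn.MSELoss()
x, t = torch.tensor([2.]), torch.tensor([2.])
x_n, t_n = torch.tensor([3.]), torch.tensor([3.])

m = MyModel()
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)  # lr plays the role of step2
optimizer.zero_grad()
meta_outer_loss(m, loss_fn, (x, t), (x_n, t_n), inner_lr=0.1).backward()
print(m.f.weight.grad)  # approximately 1.44, matching the first snippet
optimizer.step()        # param = param - step2 * meta-gradient

# Nothing in the helper is specific to the linear model, e.g. a small MLP:
torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(1, 8), nn.Tanh(), nn.Linear(8, 1))
meta_outer_loss(mlp, loss_fn, (x, t), (x_n, t_n), inner_lr=0.1).backward()
print([p.grad.shape for p in mlp.parameters()])
```

Older releases exposed a similar function in torch.nn.utils.stateless. If some parameters do not influence the inner loss, torch.autograd.grad would need allow_unused=True and the resulting None gradients would have to be handled before building the fast weights.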

Thanks a lot.