I tried to implement a mini-example of MAML, which involves back-propagating through gradients.
Mathematically, I consider a simple linear model ‘output = param * input’ with MSELoss. I would like to do the following update:
    param_new = param - step1 * gradient wrt param of MSELoss(param, train_data)
    param = param - step2 * gradient wrt param of MSELoss(param_new, test_data)
In particular, the term ‘gradient wrt param of MSELoss(param_new, test_data)’ is crucial.
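As a sanity check on the target value, here is a minimal pure-Python sketch of the chain rule behind that term. It assumes the scalar model and a squared-error loss on a single sample, and it uses the numbers from the code in this post (param = 3, train data (2, 2), test data (3, 3), step1 = 0.1). The helper names (`loss`, `grad`, `inner_update`, `meta_grad`, `numeric_meta_grad`) are hypothetical and only for illustration.

```python
# Scalar model: output = param * input, loss = (output - target)^2.
# All function names here are illustrative, not part of any library.

def loss(param, x, t):
    return (param * x - t) ** 2

def grad(param, x, t):
    # d/dparam of (param * x - t)^2
    return 2.0 * (param * x - t) * x

def inner_update(param, x, t, step1):
    # param_new = param - step1 * gradient on the training data
    return param - step1 * grad(param, x, t)

def meta_grad(param, train, test, step1):
    x, t = train
    x_n, t_n = test
    param_new = inner_update(param, x, t, step1)
    d_loss_d_new = grad(param_new, x_n, t_n)    # gradient wrt param_new
    d_new_d_param = 1.0 - step1 * 2.0 * x * x   # derivative of the inner update wrt param
    return d_loss_d_new * d_new_d_param, d_loss_d_new

def numeric_meta_grad(param, train, test, step1, eps=1e-6):
    # Central finite difference through the whole update, as an independent check.
    def outer(p):
        return loss(inner_update(p, *train, step1), *test)
    return (outer(param + eps) - outer(param - eps)) / (2 * eps)

wrt_param, wrt_param_new = meta_grad(3.0, (2.0, 2.0), (3.0, 3.0), 0.1)
print(wrt_param, wrt_param_new)  # approximately 1.44 and 7.2
```

Under these assumptions, the gradient wrt param is about 1.44, while the gradient wrt param_new is about 7.2, so the two are easy to tell apart in the printed output.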
I use the following code to compute the desired gradient, i.e. the gradient wrt param of MSELoss(param_new, test_data):
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.f = nn.Linear(1, 1, bias=False)
        self.init_weight()

    def forward(self, x):
        return self.f(x)

    def init_weight(self):
        nn.init.constant_(self.f.weight, 3.)

m = MyModel()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
x, t = torch.tensor([2.], requires_grad=True), torch.tensor([2.], requires_grad=True)
x_n, t_n = torch.tensor([3.], requires_grad=True), torch.tensor([3.], requires_grad=True)

y = m(x)
l = loss_fn(y, t)
optimizer.zero_grad()
l.backward(create_graph=True)  # keep the graph so that p.grad stays differentiable

for p in m.parameters():
    p_n = p - 0.1 * p.grad  # param_new, still connected to param in the graph

y_n = torch.nn.functional.linear(x_n, p_n)  # forward pass with param_new
l_n = loss_fn(y_n, t_n)
optimizer.zero_grad()
l_n.backward()
print(m.f.weight.grad)  # gradient wrt param
Instead of calling torch.nn.functional directly, I would like to use the model ‘m’ itself to evaluate ‘y_n’ and ‘l_n’, which corresponds to the following code:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.f = nn.Linear(1, 1, bias=False)
        self.init_weight()

    def forward(self, x):
        return self.f(x)

    def init_weight(self):
        nn.init.constant_(self.f.weight, 3.)

m = MyModel()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
x, t = torch.tensor([2.], requires_grad=True), torch.tensor([2.], requires_grad=True)
x_n, t_n = torch.tensor([3.], requires_grad=True), torch.tensor([3.], requires_grad=True)

y = m(x)
l = loss_fn(y, t)
optimizer.zero_grad()
l.backward(create_graph=True)  # keep the graph so that p.grad stays differentiable

state_dict = m.state_dict()
for n, p in m.named_parameters():
    state_dict[n] = p - 0.1 * p.grad  # param_new
m.load_state_dict(state_dict)  # overwrite the model's weights with param_new

y_n = m(x_n)  # forward pass through the model itself
l_n = loss_fn(y_n, t_n)
optimizer.zero_grad()
l_n.backward()
print(m.f.weight.grad)
However, the second snippet does not compute the desired gradient. From the printed output, it seems that instead of computing ‘gradient wrt param of MSELoss(param_new, test_data)’, it computes ‘gradient wrt param_new of MSELoss(param_new, test_data)’.
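A possible explanation, sketched below, is that load_state_dict (with its default settings) copies values into the existing parameters rather than replacing them, so the autograd history of param_new is not carried over. This small check assumes a bare nn.Linear in place of MyModel; the variable names `w_new`, `has_graph_before` and `has_graph_after` are only illustrative.

```python
import torch
import torch.nn as nn

# Same setup as in the post: a single weight initialised to 3.
m = nn.Linear(1, 1, bias=False)
nn.init.constant_(m.weight, 3.0)

x, t = torch.tensor([2.0]), torch.tensor([2.0])
loss = nn.functional.mse_loss(m(x), t)

# Differentiable gradient of the training loss wrt the weight.
(g,) = torch.autograd.grad(loss, m.weight, create_graph=True)

# The updated weight is a non-leaf tensor that remembers how it was computed.
w_new = m.weight - 0.1 * g
has_graph_before = w_new.grad_fn is not None

# Loading it back copies the numbers into the existing parameter.
m.load_state_dict({"weight": w_new})
has_graph_after = m.weight.grad_fn is not None

print(has_graph_before, has_graph_after, m.weight.is_leaf)
```

If the parameter comes back as a leaf tensor with no grad_fn, a later backward() can only reach the new value and cannot flow through the update step to the original param.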
I would like to construct a more complicated model (beyond a simple linear one) and do a similar meta update. Is there an elegant way to compute the desired gradient that uses the model itself instead of explicitly calling torch.nn.functional?
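For reference, one candidate approach is sketched below. It assumes PyTorch 2.0 or later, where torch.func.functional_call is available, and it is only a sketch under that assumption rather than a definitive MAML implementation. The idea is to run the same module with the updated tensors substituted for its parameters, so the graph back to the original param stays intact. The names `params` and `new_params` are illustrative.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # assumption: PyTorch >= 2.0

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = nn.Linear(1, 1, bias=False)
        nn.init.constant_(self.f.weight, 3.0)

    def forward(self, x):
        return self.f(x)

m = MyModel()
loss_fn = nn.MSELoss()
x, t = torch.tensor([2.0]), torch.tensor([2.0])
x_n, t_n = torch.tensor([3.0]), torch.tensor([3.0])

params = dict(m.named_parameters())

# Inner step: keep the graph so the outer gradient can flow through the update.
l = loss_fn(m(x), t)
grads = torch.autograd.grad(l, list(params.values()), create_graph=True)
new_params = {n: p - 0.1 * g for (n, p), g in zip(params.items(), grads)}

# Outer loss: call the same module, but with the updated (non-leaf) tensors as weights.
y_n = functional_call(m, new_params, (x_n,))
l_n = loss_fn(y_n, t_n)

m.zero_grad()
l_n.backward()
print(m.f.weight.grad)  # gradient wrt the original param
```

Because the module's own parameters are never overwritten, the same pattern should carry over to deeper models, as long as every parameter appears in the dictionary passed to functional_call.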
Thanks a lot.