I am trying to implement MAML (https://arxiv.org/abs/1703.03400), and I need to update the parameters of a module while preserving the autograd graph that links the tensors holding the new parameter values back to the original parameters. A very simplified example:
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 1)

    def forward(self, x):
        return self.layer(x)

first_model = MyModule()
second_model = MyModule()

learn_rate = 0.01  # inner-loop (task adaptation) learning rate, placeholder value
optimizer = torch.optim.SGD(first_model.parameters(), lr=0.001)  # meta-optimizer, placeholder value

x = torch.normal(mean=torch.zeros([10, 3]), std=torch.ones([10, 3]))
y = torch.normal(mean=torch.zeros([10, 1]), std=torch.ones([10, 1]))

# Inner-loop step: create_graph=True keeps the graph so the meta-gradient can flow back.
output = first_model(x)
loss = ((output - y) ** 2).mean()
grads = torch.autograd.grad([loss], list(first_model.parameters()),
                            retain_graph=True, create_graph=True)
new_params = [p - learn_rate * grad for p, grad in zip(first_model.parameters(), grads)]

# TODO: Assign new_params to second_model.parameters() while keeping the dependency
# of new_params on first_model.parameters() ???
# Something like
# for p, v in zip(second_model.parameters(), new_params):
#     p.update(v)

# Outer-loop step: evaluate the adapted parameters on new data and update first_model.
new_x = torch.normal(mean=torch.zeros([10, 3]), std=torch.ones([10, 3]))
new_y = torch.normal(mean=torch.zeros([10, 1]), std=torch.ones([10, 1]))

new_output = second_model(new_x)
new_loss = ((new_output - new_y) ** 2).mean()
new_grads = torch.autograd.grad([new_loss], list(first_model.parameters()))
for param, grad in zip(first_model.parameters(), new_grads):
    param.grad = grad
optimizer.step()
I am aware of a workaround: write nn.Modules whose forward pass takes the desired parameters as explicit inputs and computes the outputs with nn.functional, instead of using ready-made modules such as nn.Linear or nn.Conv (see the sketch below). However, this is not an option in my case, since I have many large and complex models that use ready-made modules and that I cannot modify.
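For reference, here is a minimal sketch of what I mean by that workaround (the FunctionalLinearModule name, the hard-coded learning rate, and the random data are just illustrative placeholders, not my actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of the workaround: the module accepts the parameters to use as explicit
# forward() arguments, so updated tensors keep their graph history automatically.
class FunctionalLinearModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(1, 3))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, weight=None, bias=None):
        # Fall back to the module's own parameters when no override is given.
        weight = self.weight if weight is None else weight
        bias = self.bias if bias is None else bias
        return F.linear(x, weight, bias)

model = FunctionalLinearModule()
x = torch.randn(10, 3)
y = torch.randn(10, 1)

# Inner-loop step with create_graph=True.
loss = ((model(x) - y) ** 2).mean()
grads = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
# The "fast" parameters stay connected to the original ones in the autograd graph.
fast_params = [p - 0.01 * g for p, g in zip(model.parameters(), grads)]

# Outer-loop step: forward pass with the updated parameters, without touching
# model.parameters(); meta-gradients flow back to the original parameters.
new_loss = ((model(x, *fast_params) - y) ** 2).mean()
meta_grads = torch.autograd.grad(new_loss, list(model.parameters()))

This works for a toy module, but rewriting every layer of my real models in this style is exactly what I want to avoid.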