Assign module parameters while retaining existing graph dependencies

I am trying to implement MAML and I need to update the parameters of a module while keeping any existing graph dependency information related to the tensors holding the new parameter values. A very simplified example:

import torch
import torch.nn as nn

class MyModule(nn.Module):

  def __init__(self):
    super().__init__()  # required before registering submodules
    self.layer = nn.Linear(3, 1)

  def forward(self, x):
    return self.layer(x)

first_model = MyModule()
second_model = MyModule()

x = torch.normal(mean=torch.zeros([10,3]), std=torch.ones([10,3]))
y = torch.normal(mean=torch.zeros([10,1]), std=torch.ones([10,1]))
output = first_model(x)
loss = ((output - y)**2).mean()

grads = torch.autograd.grad([loss], first_model.parameters(), retain_graph=True, create_graph=True)

learn_rate = 0.01  # inner-loop step size (value chosen only for illustration)
new_params = [p - learn_rate * grad for p, grad in zip(first_model.parameters(), grads)]

# TODO: Assign new_params to second_model.parameters() while keeping the dependency of new_params on first_model.parameters() ???
# Something like
# for p, v in zip(second_model.parameters(), new_params):
#   p.update(v)

new_x = torch.normal(mean=torch.zeros([10,3]), std=torch.ones([10,3]))
new_y = torch.normal(mean=torch.zeros([10,1]), std=torch.ones([10,1]))
new_output = second_model(new_x)
new_loss = ((new_output - new_y)**2).mean()

new_grads = torch.autograd.grad([new_loss], first_model.parameters())

for param, grad in zip(first_model.parameters(), new_grads):
  param.grad = grad


I am aware of a workaround: write nn.Modules whose forward pass takes the desired parameters as input and computes the outputs with nn.functional, instead of using ready-made modules such as nn.Linear, nn.Conv. However, this is not an option in my case, since I have many large and complex models that use ready-made modules and that I cannot modify.
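For reference, the workaround mentioned above might look roughly like the following. This is only a minimal sketch: FunctionalLinear, fast_weight, fast_bias and the learn_rate value are names and numbers made up for illustration, not part of the models in question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalLinear(nn.Module):
    """Hypothetical module that owns no parameters; they are passed to forward()."""

    def forward(self, x, weight, bias):
        return F.linear(x, weight, bias)


# Meta-parameters live outside the module as ordinary leaf tensors.
weight = torch.randn(1, 3, requires_grad=True)
bias = torch.zeros(1, requires_grad=True)
learn_rate = 0.1  # assumed inner-loop step size

model = FunctionalLinear()
x, y = torch.randn(10, 3), torch.randn(10, 1)

# Inner step: keep the graph so the update itself stays differentiable.
loss = ((model(x, weight, bias) - y) ** 2).mean()
grads = torch.autograd.grad(loss, [weight, bias], create_graph=True)
fast_weight, fast_bias = [p - learn_rate * g for p, g in zip([weight, bias], grads)]

# Outer step: the forward pass uses the updated tensors, and the gradients
# flow back to the original weight and bias.
new_loss = ((model(x, fast_weight, fast_bias) - y) ** 2).mean()
new_loss.backward()
```

The drawback is exactly the one stated above: every layer has to be rewritten in this style.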


Does that mean that your second_model does not actually have any learnable parameters?

Yes. I only need to be able to update its parameters (keeping the existing graph dependencies) and do forward passes and backpropagation.

In that case, you can simply replace the fields in question:

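# Delete the registered nn.Parameter first; a plain tensor cannot overwrite it directly.
del second_model.layer.weight, second_model.layer.bias
second_model.layer.weight = new_weight
second_model.layer.bias = new_bias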

You can get these locations automatically by using named_parameters() instead of parameters().
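A rough sketch of how that could be automated is below. assign_tensors is a hypothetical helper written for this example (it is not a PyTorch function), and the learn_rate value is an assumption. It walks the dotted names yielded by named_parameters(), deletes each registered parameter, and sets the plain tensor in its place.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 1)

    def forward(self, x):
        return self.layer(x)


def assign_tensors(module, named_tensors):
    """Hypothetical helper: replace attributes of `module` with plain tensors."""
    for name, tensor in named_tensors.items():
        *path, attr = name.split(".")  # "layer.weight" -> ["layer"], "weight"
        owner = module
        for part in path:
            owner = getattr(owner, part)
        delattr(owner, attr)           # drop the registered nn.Parameter
        setattr(owner, attr, tensor)   # now stored as an ordinary attribute


first_model, second_model = MyModule(), MyModule()
learn_rate = 0.1  # assumed inner-loop step size

x, y = torch.randn(10, 3), torch.randn(10, 1)
loss = ((first_model(x) - y) ** 2).mean()

names, params = zip(*first_model.named_parameters())
grads = torch.autograd.grad(loss, params, create_graph=True)
new_params = {n: p - learn_rate * g for n, p, g in zip(names, params, grads)}

assign_tensors(second_model, new_params)

new_x, new_y = torch.randn(10, 3), torch.randn(10, 1)
new_loss = ((second_model(new_x) - new_y) ** 2).mean()
new_grads = torch.autograd.grad(new_loss, params)
for p, g in zip(params, new_grads):
    p.grad = g
```

Note that once the tensors are assigned, second_model.parameters() is empty, so the names have to come from first_model (as here) or be recorded before the first assignment.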