Keeping loaded parameters in computation graph of forward pass


My question is whether it is possible to load parameters into a model such that they remain part of the computation graph of the model's forward pass.

Consider this example:

import torch
import torch.nn as nn
from torch.autograd import grad

class SimpleNet(nn.Module):
  def __init__(self):
    super().__init__()  # required before registering submodules
    self.fc1 = nn.Linear(4, 1)

  def forward(self, x):
    return self.fc1(x)

model = SimpleNet()

new_theta = {n: p * 2 for n, p in model.named_parameters()}

x = torch.randn(4)
result = model(x)

grad(result, [p for p in new_theta.values()], allow_unused=True)

The grad call at the end returns None for every tensor, because the tensors in new_theta are not part of the computation graph that produces result, i.e. the forward function (at least that is my understanding).

Is it somehow possible to load parameters into the model so that these gradients get populated?

load_state_dict copies the parameters in place inside a no_grad context, so gradients are not tracked back to new_theta. The model's own parameters are still used and show valid gradients:

grad(result, [p for p in model.parameters()], allow_unused=True)
# (tensor([[ 0.4696, -0.3762, -0.4090, -1.1362]]), tensor([1.]))
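The behaviour described above can be checked with a small sketch. It assumes a bare nn.Linear instead of the SimpleNet from the question, and the variable names are illustrative only:

```python
import torch
import torch.nn as nn
from torch.autograd import grad

model = nn.Linear(4, 1)
new_theta = {n: p * 2 for n, p in model.named_parameters()}
expected_weight = new_theta["weight"].detach().clone()

# Copies the values into the existing parameters; no graph edge to new_theta is created.
model.load_state_dict(new_theta)

x = torch.randn(4)
result = model(x)

# new_theta is not part of the graph of result, so every entry is None.
grads_new = grad(result, list(new_theta.values()), allow_unused=True, retain_graph=True)
# The model's own parameters are part of the graph and get real gradients.
grads_model = grad(result, list(model.parameters()))

print(grads_new)
print(grads_model)
```

The values are copied, but the gradient only flows to the model's own parameters.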

What’s your use case exactly, and why do you want to see the gradients in new_theta rather than in the parameters that are actually used (you could still copy the gradients if needed)?

First of all, thanks for the answer.

My use case is an implementation of MAML.

I have a theta (the parameters), which I want to use to calculate forward passes. Basically, I want to tell the model “do a normal forward pass, but use these parameters”, so that the computation graph includes the parameters passed.

One way to achieve this is to do something like:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

class SimpleNet(nn.Module):
  def __init__(self):
    super().__init__()  # required before registering submodules
    self.fc1 = nn.Linear(4, 1)

  def forward(self, x, theta):
    return F.linear(x, theta[0], theta[1])

or similar.

In this case I would still define self.fc1 in the model, so that I can access things like the initial parameters via model.parameters(), but the actual forward pass is computed with “external” parameters.

However, this becomes very cumbersome for complicated models, so it would be nice to handle this differently.
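For illustration, the manual approach above might be used as follows. This is only a sketch: theta is assumed to be a list ordered as [weight, bias], and the factor 2 is arbitrary:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 1)  # only holds the initial parameters

    def forward(self, x, theta):
        # theta is assumed to be [weight, bias]
        return F.linear(x, theta[0], theta[1])

model = SimpleNet()
theta = [p * 2 for p in model.parameters()]

x = torch.randn(4)
result = model(x, theta)

# theta is now part of the graph, so the gradients are populated.
grads = grad(result, theta)
print(grads)
```

Every layer would need the same treatment in forward, which is where this gets cumbersome.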

You might want to try torch.func.functional_call. From its documentation:

“Performs a functional call on the module by replacing the module parameters and buffers with the provided ones.”
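Applied to the example from the question, a minimal sketch could look like this (a bare nn.Linear stands in for SimpleNet, which is an assumption made for brevity):

```python
import torch
import torch.nn as nn
from torch.autograd import grad
from torch.func import functional_call

model = nn.Linear(4, 1)
new_theta = {n: p * 2 for n, p in model.named_parameters()}

x = torch.randn(4)
# Run the module's forward pass with new_theta substituted for its parameters.
result = functional_call(model, new_theta, (x,))

# new_theta is part of the graph now, so no allow_unused is needed.
grads = grad(result, list(new_theta.values()))
print(grads)
```

The module definition stays untouched; only the call site changes.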

Okay, that sounds like just what I need.

Thanks! I will try it out later and come back here to accept the solution or ask again.

I tried it out and it does just what I want.
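For the MAML use case mentioned earlier, one inner/outer step might be sketched roughly like this. The data is random, and inner_lr, the loss, and the single inner step are assumptions for illustration, not a full MAML implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(4, 1)
inner_lr = 0.1  # hypothetical inner-loop learning rate

# Synthetic support/query data for a single task
x_support, y_support = torch.randn(8, 4), torch.randn(8, 1)
x_query, y_query = torch.randn(8, 4), torch.randn(8, 1)

theta = dict(model.named_parameters())

# Inner step: gradient of the support loss w.r.t. theta, keeping the graph
support_loss = F.mse_loss(functional_call(model, theta, (x_support,)), y_support)
inner_grads = grad(support_loss, list(theta.values()), create_graph=True)
adapted = {n: p - inner_lr * g for (n, p), g in zip(theta.items(), inner_grads)}

# Outer loss: forward pass with the adapted parameters
query_loss = F.mse_loss(functional_call(model, adapted, (x_query,)), y_query)
query_loss.backward()  # gradients flow through the inner step back to theta

print(model.weight.grad, model.bias.grad)
```

Because create_graph=True is used in the inner step, the outer gradient includes the second-order terms.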

However, I have one question about buffers and in-place operations that the documentation did not answer.

If I have the following net:

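class SimpleNet(nn.Module):
  def __init__(self):
    super().__init__()
    # "net" is a placeholder attribute name
    self.net = nn.Sequential(
        nn.Conv2d(3, 12, 3),
        nn.BatchNorm2d(12),  # assumed: provides the buffers discussed below
        nn.Conv2d(12, 1, 3),
        nn.Flatten(),  # assumed: flattens 1x6x6 to the 36 features Linear expects
        nn.Linear(36, 1)
    )

  def forward(self, x):
    return self.net(x)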

and use it like this:

from torch.func import functional_call

model = SimpleNet()
x = torch.randn(5, 3, 10, 10)

external_params = {n: p*2 for n, p in model.named_parameters()}
external_buffers = {n: torch.ones_like(b) for n, b in model.named_buffers()}

result = functional_call(model, (external_params, external_buffers), x)

Then, I can see that the external_buffers are updated (as described in the documentation). However, I was surprised that

result = functional_call(model, external_params, x) 

works as well. In this case, I assume the model's own buffers are used and updated (e.g. in the case of BatchNorm)?
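One way to check which buffers get updated is to watch BatchNorm's num_batches_tracked counter. The sketch below assumes a small hypothetical model with one BatchNorm2d layer in training mode:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4))
x = torch.randn(5, 3, 10, 10)

params = {n: p * 2 for n, p in model.named_parameters()}
ext_buffers = {n: torch.ones_like(b) for n, b in model.named_buffers()}

# Case 1: external buffers are passed, so they receive the in-place updates.
functional_call(model, (params, ext_buffers), (x,))
ext_count_after_1 = ext_buffers["1.num_batches_tracked"].item()
int_count_after_1 = model[1].num_batches_tracked.item()

# Case 2: only parameters are passed, so the module's own buffers are used.
functional_call(model, params, (x,))
ext_count_after_2 = ext_buffers["1.num_batches_tracked"].item()
int_count_after_2 = model[1].num_batches_tracked.item()

print(ext_count_after_1, int_count_after_1)
print(ext_count_after_2, int_count_after_2)
```

In this sketch the external counter moves only in the first call and the module's own counter only in the second.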

Yes, the model's internal parameters and buffers will be used for any entries you do not pass, as long as strict=False (which is the default). This example also demonstrates it:

class SimpleNet(nn.Module):
  def __init__(self):
    super().__init__()  # required before registering submodules
    self.lin1 = nn.Linear(2, 2)
    self.lin2 = nn.Linear(2, 2)

  def forward(self, x):
    return self.lin1(x) + self.lin2(x)

model = SimpleNet()
x = torch.randn(1, 2)

external_params = {n: p*20000 for n, p in model.named_parameters() if "1" in n}
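result = torch.func.functional_call(model, external_params, x, strict=False)
print(result)    # lin1 uses the scaled external parameters, lin2 its own
print(model(x))  # both layers use the model's internal parameters
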
# tensor([[-4179.4907, -9056.8799],
#         [-4179.4907, -9056.8799]])
# tensor([[-0.2090, -0.4528],
#         [-0.2090, -0.4528]])
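To contrast the two settings, here is a small sketch of what strict=True does when only some parameters are passed. TwoLinear is a hypothetical name that mirrors the model above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

class TwoLinear(nn.Module):  # hypothetical name, same layout as the example above
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(2, 2)
        self.lin2 = nn.Linear(2, 2)

    def forward(self, x):
        return self.lin1(x) + self.lin2(x)

model = TwoLinear()
x = torch.randn(1, 2)

# Only lin1's parameters are replaced; lin2's are not passed.
partial_params = {n: p * 2 for n, p in model.named_parameters() if "1" in n}

# strict=False: missing entries fall back to the module's own parameters.
out = functional_call(model, partial_params, (x,), strict=False)
expected = F.linear(x, partial_params["lin1.weight"], partial_params["lin1.bias"]) + model.lin2(x)

# strict=True: the missing lin2 entries are reported as an error.
try:
    functional_call(model, partial_params, (x,), strict=True)
    strict_raised = False
except RuntimeError as err:
    strict_raised = True
    print(err)
```

With strict=True the passed keys must match the module's parameters and buffers exactly.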

Okay, got it, perfect!

Thanks, and thank you!
