Computation graph is interrupted when calling a model within a model

Hi all,

I am trying to do the following:

  1. Create a model (model A) with a few linear layers that produces X, a list of tensors
  2. Create a model (model B) with a few linear layers during the forward pass of model A (or create model B once, outside the training loop, and reference it here each time)
  3. Overwrite the parameters of model B with X:
idx = 0
for n, p in model_B.named_parameters():
    w = X[idx].reshape(128, 128)
    idx += 1 
    p.data = w.float() 
  4. Run model B on some input Y using the new parameters:
out = model_B(Y)
loss = F.l1_loss(out, ground_truth)
  5. Finally, call loss.backward() in the training loop and update the parameters of model A

Currently, the gradient of the parameters of model A is None. Maybe the graph is interrupted when I replace the parameters of model B with X, since p.data = w.float() is implicitly performed under no_grad()? I’m not sure if this is the issue, though.

Some additional information:

  1. X[0]…X[t] all have requires_grad == True
  2. all of model_B parameters have is_leaf == True and requires_grad == True
  3. if I comment out the code for model B and run something simple such as loss = X[0].mean(), then the gradients of the parameters of model A are valid, so the problem probably isn’t with model A itself
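For reference, here is a minimal sketch that reproduces the symptom. It assumes model B is a single bias-free nn.Linear and model A is one nn.Linear whose output is reshaped into B's weight; the sizes and the names z and in_features are hypothetical. It also shows why model B's parameters still report is_leaf == True: assigning to .data swaps in the values but not their autograd history, so nothing links the loss back to X.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: model B is one 4x4 linear layer without bias,
# so model A only has to produce 16 numbers.
in_features = 4
model_A = nn.Linear(8, in_features * in_features)
model_B = nn.Linear(in_features, in_features, bias=False)

z = torch.randn(1, 8)                        # hypothetical input to model A
Y = torch.randn(2, in_features)              # input to model B
ground_truth = torch.randn(2, in_features)

X = model_A(z)                               # X has a grad_fn leading back to model A
w = X.reshape(in_features, in_features)

# Assigning to .data takes w's values but drops its autograd history:
# model_B.weight stays a leaf with no connection to X.
model_B.weight.data = w.float()
print(model_B.weight.is_leaf, model_B.weight.grad_fn)  # True None

loss = F.l1_loss(model_B(Y), ground_truth)
loss.backward()

print(model_B.weight.grad is None)           # False: B's leaf receives a gradient
print(model_A.weight.grad is None)           # True: nothing flows back to model A
```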

Any advice would be greatly appreciated, thanks!

Hi gchou!

Models typically have Parameters, and Parameters are typically things
that are optimized.

If I understand your use case, you don’t intend to optimize the “parameters”
(weights) of model B, but rather derive (“predict”) them using another model,
model A (whose Parameters you do intend to optimize).

There are a lot of ways of packaging this, but I think that writing model B
as a function, rather than structuring it as an “official model” (i.e., Module),
would be the cleanest approach and would communicate more clearly what
you are doing.

The following abbreviated pseudo-code illustrates this approach:

import torch

def functional_model_b (input, weights):   # takes an input tensor and a list of four weight tensors
    t = torch.nn.functional.linear (input, weight = weights[0], bias = weights[1])
    t = torch.nn.functional.relu (t)
    t = torch.nn.functional.linear (t, weight = weights[2], bias = weights[3])
    return t

model_b_weights = model_A (some_input)   # returns list of four tensors
pred = functional_model_b (some_other_input, model_b_weights)
loss = criterion (pred, targ)
loss.backward()   # backpropagates through functional_model_b and computes gradients for model_A's parameters
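A self-contained, runnable version of this pseudo-code might look like the following. It is a sketch under assumptions: the sizes, the hypothetical ModelA class, and the way its flat output is split into the four tensors with torch.split are illustrative choices, not part of the original suggestion. The key point is that model B's weights are ordinary tensors produced by model A (a pattern often called a hypernetwork), so the graph runs unbroken from the loss back to model A's Parameters.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes for illustration only.
D_IN, D_HID, D_OUT, D_COND = 3, 5, 2, 4

# Shapes of model B's four "weights": [w1, b1, w2, b2].
SHAPES = [(D_HID, D_IN), (D_HID,), (D_OUT, D_HID), (D_OUT,)]
SIZES = [math.prod(s) for s in SHAPES]

def functional_model_b(input, weights):
    # Two-layer MLP whose weights are passed in instead of stored as Parameters.
    t = F.linear(input, weight=weights[0], bias=weights[1])
    t = F.relu(t)
    t = F.linear(t, weight=weights[2], bias=weights[3])
    return t

class ModelA(nn.Module):
    # Hypothetical model A: maps a conditioning vector to one flat vector,
    # then splits and reshapes it into model B's four weight tensors.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(D_COND, 32), nn.ReLU(), nn.Linear(32, sum(SIZES)))

    def forward(self, cond):
        flat = self.net(cond)
        chunks = torch.split(flat, SIZES)
        return [c.reshape(s) for c, s in zip(chunks, SHAPES)]

model_A = ModelA()
optimizer = torch.optim.SGD(model_A.parameters(), lr=0.1)

some_input = torch.randn(D_COND)             # conditioning input for model A
some_other_input = torch.randn(8, D_IN)      # batch fed to model B
targ = torch.randn(8, D_OUT)

optimizer.zero_grad()
model_b_weights = model_A(some_input)        # four tensors, still attached to model A's graph
pred = functional_model_b(some_other_input, model_b_weights)
loss = F.l1_loss(pred, targ)
loss.backward()                              # gradients reach model A's Parameters
optimizer.step()

print(all(p.grad is not None for p in model_A.parameters()))  # True
```

If you would rather keep model B as an nn.Module, torch.func.functional_call (PyTorch 2.0 and later) runs a module's forward with a dict of tensors substituted for its parameters, which likewise keeps the graph intact without touching .data.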

Best.

K. Frank


Hi K. Frank,

Thanks so much for the response, this would work perfectly for what I need to do!
Much cleaner than what I was going for as well :slight_smile: