Trying to backward through the graph a second time. Model composition

Hi everyone,
I have two models, say model_a, model_b.
model_b takes as input the output of model_a as well as some other additional input.
For each input to model_a there are many additional inputs to model_b.
I want to take a gradient step on model_b’s weights after each additional input, while accumulating gradients in model_a’s weights and taking a single step on model_a only after all additional inputs for that example have been forwarded through model_b.
I want to achieve something like this:

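```python
import torch

examples = 10
additional_input_per_example = 100
batch_size = 32
hidden_dim = 3

# model_a, model_b, criterion, optimizer_a and optimizer_b are defined elsewhere
for i in range(examples):
    x1 = torch.randn(batch_size, hidden_dim)
    optimizer_a.zero_grad()
    enc = model_a(x1)  # Assume this takes 1 hour
    for j in range(additional_input_per_example):
        x2 = torch.randn(batch_size, hidden_dim)
        y = torch.randn(batch_size, hidden_dim)
        y_hat = model_b(enc, x2)  # Assume this takes 5 seconds
        loss = criterion(y_hat, y)
        optimizer_b.zero_grad()
        loss.backward()  # raises on j == 1 unless retain_graph=True is passed
        optimizer_b.step()  # step model_b after every additional input
    optimizer_a.step()  # step model_a once, after all additional inputs
```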

If I move the computation of “enc” into the inner loop it works, but it is highly inefficient because “enc” is then recomputed for every additional input.
Is the option retain_graph=True absolutely essential?
The code works with this option, but in my actual case it increases memory consumption several-fold, and it seems unnecessary, since I only need to accumulate gradients in model_a’s weights and do not use the graph for anything else.
Thanks in advance for your help.

In your current code snippet you are calculating the gradients w.r.t. model_a’s parameters multiple times inside the nested for loop, i.e. each loss.backward() call backpropagates through model_b and, via enc, into model_a. This raises the error, since the intermediate tensors that autograd saved in model_a during the forward pass are freed in the first backward call.
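A minimal reproduction of the error and of the retain_graph=True workaround, assuming tiny hypothetical stand-ins for the two models (model_a as a single nn.Linear, model_b as a linear layer over the concatenation of enc and x2). The sizes and the `run` helper are illustrative only, not taken from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the models in the question.
model_a = nn.Linear(3, 3)
head_b = nn.Linear(6, 3)

def model_b(enc, x2):
    # model_b consumes model_a's output together with an additional input
    return head_b(torch.cat([enc, x2], dim=1))

def run(retain_graph, n_inner=3):
    """Forward model_a once, then call backward once per additional input.
    Returns True if every backward call succeeded."""
    enc = model_a(torch.randn(4, 3))  # model_a's graph is built only once
    try:
        for _ in range(n_inner):
            x2, y = torch.randn(4, 3), torch.randn(4, 3)
            loss = nn.functional.mse_loss(model_b(enc, x2), y)
            # each call goes through model_b and then back into model_a via enc
            loss.backward(retain_graph=retain_graph)
    except RuntimeError as err:
        print("RuntimeError:", str(err).splitlines()[0])
        return False
    return True

print(run(retain_graph=False))  # False: the second backward hits freed buffers
print(run(retain_graph=True))   # True: model_a's saved tensors are kept alive
```

With retain_graph=True, the gradients from all inner iterations are summed in model_a’s .grad, but note that every inner call also backpropagates through the whole of model_a again.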
If you want to accumulate gradients in model_a, you could use retain_graph=True. If not, you can call .detach() on enc before passing it to model_b.
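Building on the .detach() suggestion, here is a sketch (not taken from this thread) of one common pattern that still accumulates gradients into model_a without retain_graph=True: detach enc into a new leaf that requires grad, let each inner loss.backward() stop at that leaf so that d(loss)/d(enc) is summed in its .grad, and then backpropagate once through model_a with enc.backward(enc_leaf.grad). By the chain rule this yields the same model_a gradient as the retain_graph=True loop (up to floating-point rounding), since each per-step gradient w.r.t. enc is still computed with model_b’s weights at that step. The models, sizes, SGD optimizers and lr=0.1 below are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, hidden_dim, n_inner = 32, 3, 5

# Hypothetical stand-ins for the models in the question.
model_a = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.Tanh())
head_b = nn.Linear(2 * hidden_dim, hidden_dim)

def model_b(enc, x2):
    return head_b(torch.cat([enc, x2], dim=1))

criterion = nn.MSELoss()
optimizer_a = torch.optim.SGD(model_a.parameters(), lr=0.1)
optimizer_b = torch.optim.SGD(head_b.parameters(), lr=0.1)

x1 = torch.randn(batch_size, hidden_dim)
optimizer_a.zero_grad()
enc = model_a(x1)  # expensive forward pass, done once per example

# Cut the graph: model_b's backward stops at this leaf and
# accumulates d(loss)/d(enc) into enc_leaf.grad across iterations.
enc_leaf = enc.detach().requires_grad_()

for j in range(n_inner):
    x2 = torch.randn(batch_size, hidden_dim)
    y = torch.randn(batch_size, hidden_dim)
    loss = criterion(model_b(enc_leaf, x2), y)
    optimizer_b.zero_grad()
    loss.backward()     # only model_b's graph for this j is traversed and freed
    optimizer_b.step()  # model_b is updated after every additional input

# One backward pass through model_a with the summed gradient w.r.t. enc.
enc.backward(enc_leaf.grad)
optimizer_a.step()
```

Compared with the retain_graph=True loop, model_a’s backward pass runs once per example instead of once per additional input. model_a’s saved activations still have to stay alive until that final enc.backward call, since they are needed to compute its gradient.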


Thanks, Peter, for your reply.
Using retain_graph=True on the full graph seemed unnecessary, but your reply in a different thread helped me resolve the issue: