Understanding create_graph=True and when autograd frees the graph

I’m trying to understand which tensors autograd frees after calling .backward(). In the code below, case 1 works as expected, but case 3 raises an error on the second backward pass unless I pass create_graph=True to backward(), or detach the output as in case 2. This looks odd to me because the gradient is only supposed to be computed over model1’s parameters, but it seems that model2’s part of the graph is freed, or something like that happens.

import torch
import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.layer1 = nn.Linear(4, 20)
        self.layer2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x


class B(nn.Module):
    def __init__(self):
        super(B, self).__init__()
        self.layer1 = nn.Linear(4, 20)
        self.layer2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x


# case 1: extra is recomputed inside the loop, so a fresh graph through model2 is built every iteration (works)
if __name__ == '__main__':
    model1 = A()
    opt1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
    model2 = B()
    data = torch.randn(100, 4)
    label = torch.randn(100, 1)
    for epoch in range(10):
        print("{}th backward...".format(epoch))
        extra = model2(data)
        loss = torch.mean((extra - label - model1(data)).pow(2))
        opt1.zero_grad()
        loss.backward()
        opt1.step()

# case 2: extra is detached before the loop, so the loss graph never includes model2 (works)
if __name__ == '__main__':
    model1 = A()
    opt1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
    model2 = B()
    data = torch.randn(100, 4)
    label = torch.randn(100, 1)
    extra = model2(data).detach()
    for epoch in range(10):
        print("{}th backward...".format(epoch))
        loss = torch.mean((extra - label - model1(data)).pow(2))
        opt1.zero_grad()
        loss.backward()
        opt1.step()

# case 3: extra is computed once and keeps its graph through model2, which is freed after the first backward (fails on the second iteration)
if __name__ == '__main__':
    model1 = A()
    opt1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
    model2 = B()
    data = torch.randn(100, 4)
    label = torch.randn(100, 1)
    extra = model2(data)
    for epoch in range(10):
        print("{}th backward...".format(epoch))
        loss = torch.mean((extra - label - model1(data)).pow(2))
        opt1.zero_grad()
        loss.backward()
        opt1.step()

Hi,

For case 3 you don’t actually need create_graph=True, only retain_graph=True, right?

It’s because the parameters in model2 still require gradients, so when extra is computed, it requires gradients as well. When extra is then used in the loop, the same section of the graph (corresponding to the computations that happened inside model2) is shared between the different iterations. So you need to tell the engine not to free it if you want to backward through it again.
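For example, a minimal sketch of case 3 that keeps the shared subgraph alive with retain_graph=True (same variables as your code above; the cost is that those saved buffers stay in memory):

extra = model2(data)  # extra carries a graph back through model2's parameters
for epoch in range(10):
    loss = torch.mean((extra - label - model1(data)).pow(2))
    opt1.zero_grad()
    # retain_graph=True asks the engine not to free the saved buffers,
    # so the shared model2 subgraph can be backwarded through again next epoch
    loss.backward(retain_graph=True)
    opt1.step()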

If, in your case, you don’t need gradients when doing the forward of a given model, you have two options:

  • If it’s temporary and the input does not require gradients, you can wrap the forward in a no_grad block:
with torch.no_grad():
    extra = model2(data)
  • If it’s permanent, you can also iterate over each parameter in model2 and mark it as requires_grad=False (see the sketch just below).
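
For the second option, a minimal sketch, assuming you want model2 frozen for good:

for p in model2.parameters():
    p.requires_grad_(False)   # freeze model2 permanently
extra = model2(data)          # extra now has no grad_fn, so the loop from case 3 works without retain_graph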

Thanks for confirming that calling backward() on the loss also frees the part of the graph that goes through model2.
Your solutions are great, but I want to keep model2 untouched because later I need to backward through its parameters. To achieve this, I tried retain_graph=True when calling backward() on model1’s loss, but that leads to a memory explosion after many epochs.
How can I keep model2’s parameters requiring grad without causing memory problems?

Then you need to temporarily disable autograd when you run model2’s forward but don’t want gradients for it:

with torch.no_grad():
    extra = model2(data)

And you can remove retain_graph=True.
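
To make that concrete, here is a hedged sketch of the full pattern, reusing the names from the code above (the second loss is only an illustrative placeholder for whatever objective you later use for model2):

# while training model1, run model2 without building a graph
with torch.no_grad():
    extra = model2(data)
for epoch in range(10):
    loss = torch.mean((extra - label - model1(data)).pow(2))
    opt1.zero_grad()
    loss.backward()  # no retain_graph needed: extra has no graph attached
    opt1.step()

# later, when you do need gradients w.r.t. model2's parameters,
# run a fresh forward pass with autograd enabled
extra = model2(data)
loss2 = torch.mean((extra - label - model1(data).detach()).pow(2))
loss2.backward()     # populates .grad on model2's parameters only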