PyTorch runtime error: retain_graph=True

Hello everyone.

I have two different Jupyter notebooks that do almost the same thing.

In the first notebook I do the following:
a) Download vgg16.
b) Pass my images through vgg16 and take the output vector from an intermediate layer.
c) Create a matrix from all those vectors.
d) Create my own MLP and train it on this matrix.

In the second notebook I do exactly the same thing, except that instead of the vgg16 model I use resnet50 with the pre-trained weights of the MoCo model. All the other code is the same.

My problem is that when I train my MLP with the resnet50 MoCo model, I get the following error: “Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time”.

After some searching, I changed loss.backward() to loss.backward(retain_graph=True) and it worked.
The problem is that it now runs very slowly.

I do not understand why this happens with the resnet50 MoCo model and not with vgg16.

Can anyone help me fix this?

My train function was originally the following in both notebooks:

def train_model(model, data, actual_labels):
    model.train()
    
    optimizer.zero_grad()
    
    predicted_training_labels = model(data)

    loss = loss_function(predicted_training_labels, actual_labels)
    
    accuracy = compute_accuracy(predicted_training_labels, actual_labels)

    loss.backward()

    optimizer.step()
    
    return loss.item()

Thanks a lot.

It might be you are trying to backward through resnet as well. It is hard to tell without knowing exactly how you collect the features and what the data argument represents in the train_model function: it could be a batch or the full matrix.

Hmm, I do not think I understand what you said.

To clarify, my data is a matrix with shape 1000x2048. Every row of the matrix is the output from an intermediate layer of the resnet50 MoCo model. I create the matrix first, so when I train my MLP I do not use resnet50 at all. I use it only at the beginning to create the matrix.

I cannot reproduce the error using your code with some toy data and a toy model.

import torch
import torch.nn as nn

def train_model(model, data, actual_labels):
    model.train()
    
    optimizer.zero_grad()
    
    predicted_training_labels = model(data)

    loss = loss_function(predicted_training_labels, actual_labels)
    
    # accuracy = compute_accuracy(predicted_training_labels, actual_labels)

    loss.backward()

    optimizer.step()
    
    return loss.item()

data = torch.randn(1000, 2048)
labels = torch.empty(1000, dtype=torch.long).random_(3)
loss_function = nn.CrossEntropyLoss()
model = nn.Linear(2048, 3)
optimizer = torch.optim.Adam(model.parameters())

for i in range(10):
    l = train_model(model, data, labels)
    print(l)

This code works fine, without any error.

Yes, and for me it works perfectly with vgg16, without errors and very fast.

Can you explain what you mean by this sentence: “It might be you are trying to backward through resnet as well.”

I believe the problem has something to do with the resnet50 MoCo model, because it is the only thing that I changed.

It would be helpful to take a look at the code where you collect the features from resnet. “backward through resnet” means: if you collect your features while resnet is in training mode and without the with torch.no_grad() context, your matrix stays connected to resnet's computational graph, so calling backward on the MLP loss also propagates into resnet.
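A quick way to check whether a feature matrix is still attached to a graph is to look at its requires_grad and grad_fn attributes. The snippet below is only a sketch: backbone is a hypothetical small linear layer standing in for resnet, and the tensor shapes are made up.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the pretrained backbone (resnet50 in this thread).
backbone = nn.Linear(8, 4)
images = torch.randn(5, 8)

# Gradients are tracked here, so the result is still tied to the backbone.
attached = backbone(images)

# Inside no_grad the result is a plain tensor with no graph behind it.
with torch.no_grad():
    detached = backbone(images)

print(attached.requires_grad, attached.grad_fn is not None)
print(detached.requires_grad, detached.grad_fn is None)
```

If your own matrix behaves like attached here, training on it will also run backward through the network that produced it.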

Consider this code:

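import torch
import torch.nn as nn
from torchvision import models

rn18 = models.resnet18()
rn18.fc = nn.Linear(512, 2048)
images = torch.randn(1000, 3, 32, 32)
data = rn18(images)  # no torch.no_grad(): data stays attached to rn18's graph
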
labels = torch.empty(1000, dtype=torch.long).random_(3)
loss_function = nn.CrossEntropyLoss()
model = nn.Linear(2048, 3)
optimizer = torch.optim.Adam(model.parameters())

for i in range(10):
    l = train_model(model, data, labels)
    print(l)

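It will fail, because it tries to propagate through resnet as well as through the MLP. The first call to backward frees the buffers of resnet's graph, so the second iteration raises the error. A single pass also takes much longer now.

And with these changes: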
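The failure can be reproduced without torchvision. In this sketch, backbone is a hypothetical tiny network standing in for resnet and head stands in for the MLP; the features are computed once with gradients enabled, then reused for two backward calls.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a tiny backbone instead of resnet, a linear head as the MLP.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
head = nn.Linear(4, 3)
loss_fn = nn.CrossEntropyLoss()
labels = torch.randint(0, 3, (5,))

# The backbone graph is built once, here, and shared by every later loss.
features = backbone(torch.randn(5, 8))

# First backward succeeds, then frees the buffers saved in the backbone graph.
loss_fn(head(features), labels).backward()

second_failed = False
try:
    # The head graph is new, but the backbone part was already freed.
    loss_fn(head(features), labels).backward()
except RuntimeError:
    second_failed = True

print(second_failed)
```

Passing retain_graph=True hides the error but keeps backpropagating through the whole backbone on every step, which would explain the slowdown.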

rn18 = models.resnet18()
rn18.fc = nn.Linear(512, 2048)
rn18.eval()
images = torch.randn(1000, 3, 32, 32)
with torch.no_grad():
    data = rn18(images)

labels = torch.empty(1000, dtype=torch.long).random_(3)
loss_function = nn.CrossEntropyLoss()
model = nn.Linear(2048, 3)
optimizer = torch.optim.Adam(model.parameters())

for i in range(10):
    l = train_model(model, data, labels)
    print(l)

Everything works fine.
Alternatively, you can call data.detach_() on your matrix.

If I call detach_() on my data, I find that my new model (train_model(model, data, labels)) does not update its parameters, so in the end it is not trained.

True. You should detach it only once, and only right after collecting the features from resnet. Then you can train as usual. The code snippet in my previous post is the preferred way to do so (in my opinion), as it is more explicit: you don't need to train resnet, so set it to eval; you don't want PyTorch to compute gradients while you are extracting the features, so use the with torch.no_grad() context.
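To illustrate "detach once, then train as usual", here is a minimal sketch with hypothetical stand-ins (backbone for resnet, head for the MLP). It only detaches the input features, so the head's own graph stays intact and its weights still change. If the parameters stop updating, one possible cause (a guess, since that code was not posted) is detaching something inside the training step, such as the model's output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Linear(8, 4)   # hypothetical stand-in for resnet
head = nn.Linear(4, 3)       # the MLP being trained
optimizer = torch.optim.Adam(head.parameters())
loss_fn = nn.CrossEntropyLoss()
labels = torch.randint(0, 3, (5,))

# Detach once, right after feature extraction, outside the training loop.
features = backbone(torch.randn(5, 8)).detach()

before = head.weight.detach().clone()
for _ in range(3):
    optimizer.zero_grad()
    loss = loss_fn(head(features), labels)
    loss.backward()          # no retain_graph needed: the graph only covers head
    optimizer.step()

head_updated = not torch.equal(before, head.weight)
backbone_untouched = backbone.weight.grad is None
print(head_updated, backbone_untouched)
```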

Thank you very much.
You were right. Using with torch.no_grad(), it worked perfectly.

Great to hear it helped :)
All the best!