PyTorch: no computation graph mode

I am trying to get PyTorch to work without keeping a computation graph in the background, to save GPU memory. My code is roughly as follows:

```python
import torch
import torch.nn as nn

device = torch.device("cuda")
# train_X and labels are loaded elsewhere.

with torch.no_grad():
    # Variable is deprecated; plain tensors already default to requires_grad=False.
    train_X, labels = train_X.to(device), labels.to(device)
    conv0 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, bias=True).to(device)
    relu = nn.LeakyReLU().to(device)
    avgpool = nn.AvgPool2d(kernel_size=2).to(device)
    out = conv0(train_X)  # GPU memory gets up to 2853M after this line
    out = relu(out)       # GPU memory gets up to 4769M after this line
    out = avgpool(out)    # GPU memory gets up to 5249M
```

My question: out = relu(out) seems to keep the memory previously allocated by out = conv0(train_X) (it can be released with torch.cuda.empty_cache()), even though I wrapped everything in with torch.no_grad():, which should stop the computation graph from being built. I have also tried setting requires_grad=False on train_X and conv0.weight, but it did not help. Could you tell me what is wrong? Sorry, I don't know how to adjust the code.

I know this may be quite trivial. If it is a duplicate, please point me to the original! Thank you!

Hi,

To prevent autograd from saving anything, wrapping the code in with torch.no_grad(): is enough.
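A minimal check, using a dummy input whose shape is an arbitrary assumption: an output computed under torch.no_grad() has requires_grad=False and no grad_fn, so no graph is recorded even though the layer's own parameters still have requires_grad=True. Setting requires_grad=False on the inputs or weights is therefore not needed.

```python
import torch
import torch.nn as nn

# Small conv layer; its weight and bias require grad by default.
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5)
x = torch.randn(2, 3, 32, 32)  # dummy input, shape chosen only for illustration

with torch.no_grad():
    out_ng = conv(x)  # not attached to any graph

out_g = conv(x)  # outside no_grad, autograd records the op

print(out_ng.requires_grad, out_ng.grad_fn)            # False None
print(out_g.requires_grad, out_g.grad_fn is not None)  # True True
```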
The other issue might be how you measure GPU memory: PyTorch uses a caching allocator, so the memory reported by the OS (e.g. by nvidia-smi) does not necessarily reflect what your tensors are actually using.
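To see what PyTorch itself is holding, you can query the caching allocator directly. This is a sketch that assumes a CUDA device is available; the batch shape is made up, the layers mirror the question, and the helper name report is hypothetical. torch.cuda.memory_allocated() counts bytes held by live tensors, while torch.cuda.memory_reserved() also includes cached blocks the allocator keeps for reuse, which is closer to what the OS sees.

```python
import torch
import torch.nn as nn

def report(tag):
    # memory_allocated: bytes currently occupied by live tensors.
    # memory_reserved: bytes the caching allocator holds from the driver;
    # this (plus the CUDA context) is roughly what nvidia-smi shows.
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.0f} MiB, reserved={reserved:.0f} MiB")
    return alloc, reserved

if torch.cuda.is_available():
    device = torch.device("cuda")
    x = torch.randn(16, 3, 256, 256, device=device)  # dummy batch
    conv0 = nn.Conv2d(3, 64, kernel_size=5).to(device)
    relu = nn.LeakyReLU()
    with torch.no_grad():
        out = conv0(x)
        report("after conv0")
        out = relu(out)  # old conv output is freed once out is rebound
        report("after relu")  # allocated drops back; reserved may stay high
        torch.cuda.empty_cache()  # return unused cached blocks to the driver
        report("after empty_cache")
        peak = torch.cuda.max_memory_allocated() / 2**20
        print(f"peak allocated: {peak:.0f} MiB")
```

The peak value is what matters for out-of-memory errors; the reserved figure staying high after a tensor is freed is the cache, not a leak.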
Also, when you do out = relu(out), the ReLU is computed first, so a new output tensor is allocated. Only then is the old content of out released and replaced with the new result. So I would expect to see a memory bump here.
Note that most activation functions have an inplace flag that makes them overwrite their input, which avoids this extra allocation.
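The difference can be seen by comparing storage pointers, in a sketch with an arbitrary dummy tensor: the out-of-place call returns a new buffer, while inplace=True writes the result into the input's storage. The in-place version overwrites x, which is fine in the question's pattern because the old value of out is not needed afterwards.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 64, 28, 28)  # dummy activation, shape is arbitrary

with torch.no_grad():
    # Out-of-place: a new output is allocated, so input and result coexist briefly.
    y = nn.LeakyReLU()(x)
    print(y.data_ptr() == x.data_ptr())  # False

    # In-place: the result overwrites x's storage, no new buffer is needed.
    z = nn.LeakyReLU(inplace=True)(x)
    print(z.data_ptr() == x.data_ptr())  # True
```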


Hi. Thanks for the response! I tried inplace=True, and it removed the memory increase at out = relu(out). However, the increase at out = maxpool(out) persists. Is there a way to avoid it, as with the activation? Pooling should reduce memory consumption in our case.

Hi,

Pooling cannot be done in place, so it has to allocate a new output to write the result into. It is therefore expected that memory goes up during the call, until the input can be freed.
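A small sketch of the arithmetic, assuming an arbitrary activation shape: with kernel_size=2 the output has a quarter of the input's elements, so while pool(x) runs both tensors are alive (about 1.25× the input size), and once the input is released (in the question, when out is rebound) only the 0.25× output remains.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 64, 28, 28)  # dummy activation (shape is an assumption)
pool = nn.AvgPool2d(kernel_size=2)

with torch.no_grad():
    out = pool(x)  # new buffer: input and output are both alive here

in_bytes = x.numel() * x.element_size()
out_bytes = out.numel() * out.element_size()
print(out.shape)                          # torch.Size([4, 64, 14, 14])
print(out_bytes / in_bytes)               # 0.25
print((in_bytes + out_bytes) / in_bytes)  # 1.25, peak relative to the input

del x  # once nothing else refers to the input, its memory can be reused
```

So pooling does lower steady-state memory; it just cannot avoid the short-lived peak while both buffers exist.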
