I am trying to run PyTorch without it building a computation graph in the background, in order to save GPU memory. My code is roughly as follows:
import torch
import torch.nn as nn
from torch.autograd import Variable

device = torch.device("cuda")

with torch.no_grad():
    train_X, labels = Variable(train_X, requires_grad=False).to(device), Variable(labels, requires_grad=False).to(device)
    conv0 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, bias=True).to(device)
    relu = nn.LeakyReLU().to(device)
    avgpool = nn.AvgPool2d(kernel_size=2).to(device)
    out = conv0(train_X)   # GPU memory goes up to 2853M after this line
    out = relu(out)        # GPU memory goes up to 4769M after this line
    out = avgpool(out)     # GPU memory goes up to 5249M after this line
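For what it's worth, here is a rough way to cross-check from inside the script how much of that memory is live tensors versus allocator cache (report is just a helper name I made up for this sketch):

def report(tag):
    # memory_allocated: bytes held by live tensors
    # memory_reserved: bytes cached by PyTorch's allocator (shows up in the driver's total)
    alloc = torch.cuda.memory_allocated(device) / 2**20
    reserved = torch.cuda.memory_reserved(device) / 2**20
    print(f"{tag}: allocated={alloc:.0f}M, reserved={reserved:.0f}M")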
My question is that out = relu(out) seems to keep holding the memory allocated earlier by out = conv0(train_X), which can only be cleared by calling torch.cuda.empty_cache(), even though I have used with torch.no_grad():, which should stop the computation graph from being generated. I have also tried setting train_X and conv0.weight to requires_grad=False, but that did not work either. Could you tell me what is wrong? Sorry, I don't know how to adjust the code myself. A sketch of the empty_cache() behaviour I mean follows below.
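To make that concrete, this is roughly the pattern I am describing (a minimal sketch of the same forward pass):

with torch.no_grad():
    out = conv0(train_X)      # conv output allocated here
    out = relu(out)           # LeakyReLU allocates a new tensor; the old conv output
                              # loses its last reference and stays in the allocator cache
    torch.cuda.empty_cache()  # only now is that cached block returned to the driver
    out = avgpool(out)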
I know this seems quite trivial. If this is a duplicate, please point it out! Thank you!