Element-wise operations between two convolutions cause a memory leak

Hi! I found that torch.softmax causes a GPU memory leak.
My pytorch version is 1.8.1+cu111.
When I run the code below:

import torch
from torch import nn
from torch.nn import functional as F
from torch import cuda 
 
def test(inp): 
  w = torch.rand([32, 1, 1, 1],device='cuda') 
  y = torch.softmax(F.conv2d(inp, w), 1)  
  y = F.conv_transpose2d(y, w) 
  return y

imgs = torch.zeros([128,1,512,512],device='cuda')
outp = test(imgs)  
 
# del outp
cuda.empty_cache()
print(cuda.memory_summary()) 

The output is:


After the function returns its result, the variables inside the function should be released, but they are not. In addition, if you delete the variable outp, the redundantly occupied memory is released.
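(A quick way to see the difference is to compare memory_allocated(), which counts live tensors, with memory_reserved(), which counts what the caching allocator still holds from CUDA. This is only a minimal sketch, assuming outp from the snippet above is still in scope:)

import torch
from torch import cuda

print(cuda.memory_allocated(), cuda.memory_reserved())  # before deleting outp
del outp                        # drop the only live reference to the result
cuda.empty_cache()              # let the allocator return fully unused blocks
print(cuda.memory_allocated(), cuda.memory_reserved())  # reserved memory should drop now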
For comparison, I wrote a softmax myself, and it does not show this memory leak. The code is below:

import torch
from torch import nn
from torch.nn import functional as F
from torch import cuda 

def softmax(x,dim):
  ex = torch.exp(x)
  return ex/torch.sum(ex, dim,keepdim=True)

def test(inp): 
  w = torch.rand([32, 1, 1, 1],device='cuda') 
  y = softmax(F.conv2d(inp, w), 1)  
  y = F.conv_transpose2d(y, w) 
  return y

imgs = torch.zeros([128,1,512,512],device='cuda')
outp = test(imgs)  
 
cuda.empty_cache()
print(cuda.memory_summary()) 

The output:


I don’t know what mechanism produces the non-releasable memory. If it’s a bug, please fix it as soon as possible. Thanks :pleading_face: :pleading_face: :pleading_face:

Another weird phenomenon :expressionless: :expressionless: :expressionless: :expressionless:
I found that as long as I multiply by a constant immediately after the first convolution, this kind of memory leak occurs. If I multiply by the constant right before feeding into conv_transpose2d, or do not multiply at all, the memory leak disappears. The code and results for the three cases are as follows:

import torch
from torch import nn
from torch.nn import functional as F
from torch import cuda  

def test(inp): 
  w = torch.rand([32, 1, 1, 1],device='cuda') 
  a = F.conv2d(inp, w)*5  # Case 1: multiply immediately after the convolution
  y = F.conv_transpose2d(a, w) 
  return y

imgs = torch.zeros([128,1,512,512],device='cuda')
outp = test(imgs)  
 
cuda.empty_cache()
print(cuda.memory_summary()) 


import torch
from torch import nn
from torch.nn import functional as F
from torch import cuda  

def test(inp): 
  w = torch.rand([32, 1, 1, 1],device='cuda') 
  a = F.conv2d(inp, w)
  y = F.conv_transpose2d(a*5, w)  # Case 2: multiply right before conv_transpose2d
  return y

imgs = torch.zeros([128,1,512,512],device='cuda')
outp = test(imgs)  
 
cuda.empty_cache()
print(cuda.memory_summary()) 

import torch
from torch import nn
from torch.nn import functional as F
from torch import cuda  

def test(inp): 
  w = torch.rand([32, 1, 1, 1],device='cuda') 
  a = F.conv2d(inp, w)  # Case 3: no multiplication
  y = F.conv_transpose2d(a, w) 
  return y

imgs = torch.zeros([128,1,512,512],device='cuda')
outp = test(imgs)  
 
cuda.empty_cache()
print(cuda.memory_summary()) 


Please help me!! :cry: :cry: :cry:

I don’t know if this is a bug because in your example outp is still in scope and should not be deleted.
However, it is interesting that the native implementation of softmax causes the computation to tie up more memory.

del outp was commented out in the code. I just did not post the result when outp was deleted.

I don’t know why outp takes up so much non-releasable memory when softmax is used, or even when just multiplying by a constant, as I illustrated above.

This isn’t a memory leak as the memory is still available within PyTorch.
Roughly speaking, PyTorch allocates CUDA memory in chunks and then puts tensors in them. It can only “return” (with empty_cache) chunks that are completely unused (i.e. that don’t have any tensors in them).
What happens here is that with the additional intermediate, the data of outp sits inside a 4GB allocation that is mostly free and available to PyTorch but cannot be returned.
You can get a glimpse of this by calling torch.cuda.memory_snapshot(); its documentation also links to a brief note on memory management (e.g. how to avoid the caching allocator for debugging):
print(torch.cuda.memory_snapshot())
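As a rough sketch (assuming the snapshot format of PyTorch 1.8+, where each entry describes one cached segment and its blocks), you could summarize how much of each segment is actually occupied by live tensors:

import torch

for seg in torch.cuda.memory_snapshot():
    # sum the blocks in this segment that are currently backing live tensors
    live = sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated')
    print(f"{seg['segment_type']} segment of {seg['total_size'] / 2**20:.0f} MiB, "
          f"{live / 2**20:.0f} MiB held by live tensors")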

Best regards

Thomas


Thank you very much for your answer! But I still don’t quite understand. After the function returns, why are the intermediate variables inside the function not released completely? After all, I only need the tensor returned by the function, and I don’t need anything else. And this doesn’t always happen: in the penultimate and antepenultimate examples I gave, just because the timing of multiplying by the constant is different, the function occupies 4GB more memory after returning.

I think the memory summary is confusing here. Are you seeing 4GB more used, or 4GB that is used and not available for allocations? A litmus test here would be to keep allocating tensors after the function returns and see how many succeed before an OOM error is thrown.
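A rough version of that litmus test could look like this (only a sketch; the tensor size is arbitrary and not taken from your examples):

import torch

blocks = []
try:
    while True:
        blocks.append(torch.empty(256, 1024, 1024, device='cuda'))  # ~1 GiB of float32 each
except RuntimeError:  # a CUDA out-of-memory error surfaces as a RuntimeError
    print(f"allocated {len(blocks)} extra ~1 GiB tensors before running out of memory")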

It is true that the non-releasable memory can be reallocated to a new tensor, but when the new tensor is large, the non-releasable memory cannot be combined with the released memory into one contiguous block to hold it. With about 14GB of GPU memory available (Colab), the following two examples show this situation:

import torch
from torch import nn
from torch.nn import functional as F
from torch import cuda  

def test(inp): 
  w = torch.rand([32, 1, 1, 1],device='cuda') 
  a = F.conv2d(inp, w)*5  # multiplied immediately after the convolution
  y = F.conv_transpose2d(a, w)
  return y
 
outp = test(torch.zeros([128,1,512,512],device='cuda'))  
cuda.empty_cache()
print(cuda.memory_summary())  

a = torch.zeros(11,256,1024,1024,device='cuda')  
print(cuda.memory_summary())  

import torch
from torch import nn
from torch.nn import functional as F
from torch import cuda  

def test(inp): 
  w = torch.rand([32, 1, 1, 1],device='cuda') 
  a = F.conv2d(inp, w)  # Only modified here
  y = F.conv_transpose2d(a*5, w)  # Only modified here
  return y
 
outp = test(torch.zeros([128,1,512,512],device='cuda'))  
cuda.empty_cache()
print(cuda.memory_summary())  

a = torch.zeros(11,256,1024,1024,device='cuda')  
print(cuda.memory_summary())  

How to solve this problem?

You are right, the non-releasable memory can indeed be reused by PyTorch, but it is not merged with the released memory. When the newly defined tensor is large, the non-releasable memory and the released memory cannot be used together to hold it.

So memory fragmentation (which is likely a more precise term than memory leak to describe the situation) is a thing with the caching allocator. If you wanted, you could set PYTORCH_NO_CUDA_MEMORY_CACHING=1 to get around this, at the expense of doing all allocations/deallocations through CUDA directly.
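A minimal sketch of that workaround (my assumption is that the variable must be set before the first CUDA allocation; the safest way is to export it in the shell before launching Python):

import os
os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'  # disable the caching allocator (debugging only)

import torch
x = torch.zeros(128, 1, 512, 512, device='cuda')     # each allocation now goes straight to cudaMalloc
print(torch.cuda.memory_summary())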