I am using a U-Net modified into a 3D-convolution version. The problem is that PyTorch takes a very large amount of GPU memory for my 3D U-Net; because of this, I initially had to limit the training batch size to 2.
`torch.cuda.memory_allocated()` reports low memory usage (around 5 GB), but `torch.cuda.max_memory_allocated()` reports high memory usage (around 36 GB).
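For reference, this is roughly how I read those numbers (the `report_mem` helper is just something I wrote for debugging, not part of my network):

```python
import torch

def report_mem(tag):
    # memory_allocated: tensors currently alive
    # max_memory_allocated: peak since start (or last reset)
    # memory_reserved: total blocks held by the caching allocator
    gb = 1024 ** 3
    print(f"{tag}: allocated={torch.cuda.memory_allocated() / gb:.1f} GB, "
          f"peak={torch.cuda.max_memory_allocated() / gb:.1f} GB, "
          f"reserved={torch.cuda.memory_reserved() / gb:.1f} GB")
```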
At first, I thought the cause was the excessively large 3D feature maps stored for U-Net's skip connections, but it wasn't. I found that just performing a single 3D convolution on a 128 × 128 × 128 volume makes PyTorch allocate a huge cache in memory that is not used by subsequent operations.
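Here is a minimal sketch of the case I mean (the channel counts are illustrative, not my exact network):

```python
import torch
import torch.nn as nn

# a single 3D convolution on a 128^3 volume, batch size 2 as in my training setup
conv = nn.Conv3d(1, 32, kernel_size=3, padding=1).cuda()
x = torch.randn(2, 1, 128, 128, 128, device="cuda")

out = conv(x)
torch.cuda.synchronize()
print(f"allocated: {torch.cuda.memory_allocated() / 1024**3:.1f} GB")     # live tensors only
print(f"peak:      {torch.cuda.max_memory_allocated() / 1024**3:.1f} GB") # much larger spike during the conv
```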
In order to build the network deeper and increase the batch size, I worked around the problem by putting `torch.cuda.reset_max_memory_allocated(device)` and `torch.cuda.empty_cache()` in the middle of the forward() function, but I don't think this is a nice solution…
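In code, the workaround looks roughly like this (`TinyUNet3D` is a hypothetical two-layer stand-in for my real model):

```python
import torch
import torch.nn as nn

class TinyUNet3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.Conv3d(1, 32, kernel_size=3, padding=1)
        self.dec = nn.Conv3d(32, 1, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.enc(x)
        # the workaround: reset the peak counter and return unused
        # cached blocks to the driver mid-forward
        torch.cuda.reset_max_memory_allocated(x.device)
        torch.cuda.empty_cache()
        return self.dec(x)
```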
Calling `torch.cuda.empty_cache()` also slows down the network.
- Why does this caching behavior occur?
- Is there a better solution?