3D Convolution Replicate Padding CUDA out of memory

Hello PyTorch community!

I am trying to train a 3D-convolution-based model (summary printed below using torchinfo). My input shape is (16, 3, 3, 256, 640), i.e. (N, C, D, H, W).

==========================================================================================                                                       
Layer (type:depth-idx)                   Output Shape              Param #                                                                       
==========================================================================================                                                       
Sequential: 1-1                        [16, 1, 1, 32, 80]        --                                                                              
    Conv3d: 2-1                       [16, 64, 3, 128, 320]     5,248                                                                            
    LeakyReLU: 2-2                    [16, 64, 3, 128, 320]     --                                                                               
    Conv3d: 2-3                       [16, 128, 3, 64, 160]     221,184                                                                          
    BatchNorm3d: 2-4                  [16, 128, 3, 64, 160]     256                                                                              
    LeakyReLU: 2-5                    [16, 128, 3, 64, 160]     --                                                                               
    Conv3d: 2-6                       [16, 256, 3, 32, 80]      884,736                                                                          
    BatchNorm3d: 2-7                  [16, 256, 3, 32, 80]      512                                                                              
    LeakyReLU: 2-8                    [16, 256, 3, 32, 80]      --                                                                               
    Conv3d: 2-9                       [16, 512, 3, 32, 80]      3,538,944                                                                        
    BatchNorm3d: 2-10                 [16, 512, 3, 32, 80]      1,024                                                                            
    LeakyReLU: 2-11                   [16, 512, 3, 32, 80]      --                                                                               
    Conv3d: 2-12                      [16, 1, 1, 32, 80]        13,825                                                                           
==========================================================================================                                                       
Total params: 4,665,729                                                                                                                          
Trainable params: 4,665,729                                                                                                                      
Non-trainable params: 0                                                                                                                          
Total mult-adds (G): 663.18                                                                                                                      
==========================================================================================                                                       
Input size (MB): 94.37                                                                                                                           
Forward/backward pass size (MB): 3523.54                                                                                                         
Params size (MB): 18.66                                                                                                                          
Estimated Total Size (MB): 3636.58                                                                                                               
==========================================================================================

When I use zero padding, the model runs without any memory issues, but when I change padding_mode to replicate I get a CUDA out-of-memory error, which appears to be raised while the input is being padded:

    return torch._C._nn.replication_pad3d(input, pad)
RuntimeError: CUDA out of memory. Tried to allocate 656.00 MiB (GPU 0; 31.75 GiB total capacity; 29.27 GiB already allocated; 632.00 MiB free; 29.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
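The traceback points at replication_pad3d, which fits how nn.Conv3d handles non-zero padding modes: for any padding_mode other than "zeros", the module pads the input itself with F.pad and then calls F.conv3d with padding=0. Each such layer therefore materializes a padded copy of its input, and that copy is the tensor the convolution keeps for its backward pass. A minimal CPU check of the equivalence (the shapes are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv3d(3, 8, kernel_size=3, padding=1, padding_mode="replicate")
x = torch.randn(2, 3, 3, 16, 20)  # (N, C, D, H, W), arbitrary small shape

# F.pad lists padding from the last dimension backwards:
# (W_left, W_right, H_top, H_bottom, D_front, D_back)
x_padded = F.pad(x, (1, 1, 1, 1, 1, 1), mode="replicate")
manual = F.conv3d(x_padded, conv.weight, conv.bias, stride=conv.stride, padding=0)

print(x_padded.shape)                              # torch.Size([2, 3, 5, 18, 22])
print(torch.allclose(conv(x), manual, atol=1e-6))  # True
```

With padding_mode="zeros" no such copy is created, because the padding is handled inside the convolution kernel.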

I checked the memory allocated to tensors during the forward pass (I run several forward passes through the network in a single iteration), and with replicate padding it grows significantly faster.

All values in GB.

Zero Padding

mem allocated at start of iter 1 0.1781
mem allocated  12.34                                                                                                   
mem allocated  15.80                                                                                                    
mem allocated  19.19                                                                                                    
mem allocated  20.06                                                                                                  
mem allocated  20.90
mem allocated  21.12
mem allocated  21.33
mem allocated  21.39
SUCCESSFULLY COMPLETED iter 1

Replicate Padding

mem allocated at start of iter 1 0.1781
mem allocated  12.34
mem allocated  19.27
mem allocated  26.13
mem allocated  27.90
mem allocated  29.66
mem allocated  30.13
CRASHED AT iter 1

Is this expected behaviour for replicate padding in 3D convolutions, or is there anything I can do to reduce the memory usage?

I cannot reproduce the issue; I get the same memory usage with both padding modes using:

import torch
import torch.nn as nn

print(torch.cuda.memory_allocated()/1024**3)
# > 0.0

x = torch.randn(16, 16, 3, 128, 320, device='cuda')
print(torch.cuda.memory_allocated()/1024**3)
# > 0.1171875

conv = nn.Conv3d(16, 64, kernel_size=3, padding=1).cuda()
out = conv(x)
out.mean().backward()
print(torch.cuda.memory_allocated()/1024**3)
# > 0.5861444473266602


del x, conv, out
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated()/1024**3)
# > 0.0

x = torch.randn(16, 16, 3, 128, 320, device='cuda')
print(torch.cuda.memory_allocated()/1024**3)
# > 0.1171875

conv = nn.Conv3d(16, 64, kernel_size=3, padding=1, padding_mode='replicate').cuda()
out = conv(x)
out.mean().backward()
print(torch.cuda.memory_allocated()/1024**3)
# > 0.5861444473266602
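One caveat with this measurement: it is taken after out.mean().backward(), by which point autograd has already released the tensors it saved for the backward pass, including any padded copy. In the original training loop, several forward passes run before any backward. The sketch below, which assumes PyTorch >= 1.10 for torch.autograd.graph.saved_tensors_hooks, records which tensors a single Conv3d saves for backward so the two padding modes can be compared without a GPU; the small input shape is a stand-in chosen for illustration.

```python
import torch
import torch.nn as nn


def saved_tensor_info(padding_mode):
    """Return (shape, nbytes) for every tensor autograd saves during one Conv3d forward."""
    torch.manual_seed(0)
    conv = nn.Conv3d(16, 64, kernel_size=3, padding=1, padding_mode=padding_mode)
    # Small CPU-friendly stand-in for the (16, 16, 3, 128, 320) input used above.
    x = torch.randn(2, 16, 3, 32, 40)
    saved = []

    def pack(t):
        # Called once per tensor saved for backward; the tensor is kept unchanged.
        saved.append((tuple(t.shape), t.numel() * t.element_size()))
        return t

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        conv(x)
    return saved


zeros_saved = saved_tensor_info("zeros")
replicate_saved = saved_tensor_info("replicate")
print("zeros:    ", [shape for shape, _ in zeros_saved])
print("replicate:", [shape for shape, _ in replicate_saved])
```

With zero padding the saved input is the original (2, 16, 3, 32, 40) tensor, which is alive anyway; with replicate padding it is a separate (2, 16, 5, 34, 42) tensor that stays allocated until backward runs.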

Hi, here is a code snippet that reproduces the behaviour I described. I also see the same issue with Conv2d: memory consumption is much higher with replicate padding than with zero padding.

import torch
import torch.nn as nn
import functools

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Using Device:", device)

class Network(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=5, norm_layer=nn.BatchNorm3d, padding_mode="zeros"):
        super(Network, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm3d
        else:
            use_bias = norm_layer == nn.InstanceNorm3d

        kw = 3
        padw = 1
        sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=(1, 2, 2), padding=padw, padding_mode=padding_mode),
                    nn.LeakyReLU(0.2, True)]

        nf_mult = 1
        nf_mult_prev = 1

        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)

            sequence += [
                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=(1, 2, 2), padding=padw, padding_mode=padding_mode, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)

        sequence += [
            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, padding_mode=padding_mode, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=(0, padw, padw), padding_mode=padding_mode)]  # output 1 channel prediction map

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)

scales = [0, 1, 2, 3]
frames = ["-1", "1"]
padding_mode = ["zeros", "replicate"]
LossFunc = torch.nn.MSELoss()
LossFunc.to(device)

for p_mode in padding_mode:
    print("----------------------------------------------------------------------")
    print("Using padding mode:", p_mode)
    print("----------------------------------------------------------------------")
    torch.cuda.empty_cache()
    print("Process Started Memory:", round(torch.cuda.memory_allocated()/1024**3, 3))
    print("----------------------------------------------------------------------")
    net = Network(input_nc=3, ndf=64, n_layers=3, padding_mode=p_mode)
    net.to(device)

    # Create rand input
    inputs = {}

    for scale in scales:
        for frame in frames:
            inputs["image", frame, scale] = torch.randn(16, 3, 3, 256//2**scale, 640//2**scale, device=device)

    print("Inputs Created Memory:", round(torch.cuda.memory_allocated()/1024**3, 3))
    print("----------------------------------------------------------------------")
    loss = 0

    try:
        for scale in scales:
            for frame in frames:
                print("Scale:", scale, "Frame:", frame, "Memory @ Before Net:", round(torch.cuda.memory_allocated()/1024**3, 3))
                pred = net(inputs["image", frame, scale])
                target = torch.randn_like(pred.detach())
                print("Scale:", scale, "Frame:", frame, "Memory @ After Net:", round(torch.cuda.memory_allocated()/1024**3, 3))
                loss += LossFunc(pred, target)
                print("Scale:", scale, "Frame:", frame, "Memory @ Loss:", round(torch.cuda.memory_allocated()/1024**3, 3))
            print("----------------------------------------------------------------------")
    except RuntimeError as err:
        # CUDA OOM is raised as a RuntimeError; chain it so the original message is kept
        raise RuntimeError("Ran out of memory.") from err

    print("Success")
    # Free this run's tensors before testing the next padding mode
    del inputs, pred, target, net, loss
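On the question of what can be done: the script accumulates all eight loss terms into loss before any backward, so every forward graph (including the padded copies in replicate mode) stays alive at the same time. If the final loss is just a sum of terms, one option is to call backward() on each term as soon as it is computed: gradients add up in .grad, each graph is freed right after its backward, and the result matches the summed loss up to floating-point rounding. The network and inputs below are small hypothetical stand-ins, not the model from the script.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical small network standing in for the model in the script above.
net = nn.Sequential(
    nn.Conv3d(3, 8, kernel_size=3, padding=1, padding_mode="replicate"),
    nn.LeakyReLU(0.2),
    nn.Conv3d(8, 1, kernel_size=3, padding=1, padding_mode="replicate"),
)
loss_fn = nn.MSELoss()

# Hypothetical inputs: two "frames" at two scales, shaped (N, C, D, H, W).
inputs = [torch.randn(2, 3, 3, 16 // 2**s, 20 // 2**s) for s in range(2) for _ in range(2)]
targets = [torch.randn(2, 1, 3, x.shape[-2], x.shape[-1]) for x in inputs]


def grads_after(step):
    """Zero the gradients, run `step`, and return a copy of every parameter gradient."""
    net.zero_grad()
    step()
    return [p.grad.clone() for p in net.parameters()]


def summed_backward():
    # Original pattern: keep every graph alive, then a single backward on the sum.
    total = sum(loss_fn(net(x), t) for x, t in zip(inputs, targets))
    total.backward()


def per_term_backward():
    # Backward on each term right away: its graph is freed, gradients add up in .grad.
    for x, t in zip(inputs, targets):
        loss_fn(net(x), t).backward()


g_sum = grads_after(summed_backward)
g_acc = grads_after(per_term_backward)
print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_sum, g_acc)))  # True
```

This lowers peak memory for either padding mode, since only one graph is alive at a time, but each replicate-padded layer still allocates its padded copy during that graph's lifetime.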

The output I got on PyTorch 1.5 and 1.9 was roughly:

Using Device: cuda                                                                                                                     
----------------------------------------------------------------------                                                                 
Using padding mode: zeros                                                                                                              
----------------------------------------------------------------------                                                                 
Process Started Memory: 0.0                                                                                                            
----------------------------------------------------------------------                                                                 
Inputs Created Memory: 0.251                                                                                                           
----------------------------------------------------------------------                                                                 
Scale: 0 Frame: -1 Memory @ Before Net: 0.251                                                                                          
Scale: 0 Frame: -1 Memory @ After Net: 1.892                                                                                           
Scale: 0 Frame: -1 Memory @ Loss: 1.892                                                                                                
Scale: 0 Frame: 1 Memory @ Before Net: 1.892                                                                                           
Scale: 0 Frame: 1 Memory @ After Net: 3.533                                                                                            
Scale: 0 Frame: 1 Memory @ Loss: 3.533                                                                                                 
----------------------------------------------------------------------                                                                 
Scale: 1 Frame: -1 Memory @ Before Net: 3.533                                                                                          
Scale: 1 Frame: -1 Memory @ After Net: 3.943                                                                                           
Scale: 1 Frame: -1 Memory @ Loss: 3.943                                                                                                
Scale: 1 Frame: 1 Memory @ Before Net: 3.943                                                                                           
Scale: 1 Frame: 1 Memory @ After Net: 4.353                                                                                            
Scale: 1 Frame: 1 Memory @ Loss: 4.353                                                                                                 
----------------------------------------------------------------------                                                                 
Scale: 2 Frame: -1 Memory @ Before Net: 4.353                                                                                          
Scale: 2 Frame: -1 Memory @ After Net: 4.456                                                                                           
Scale: 2 Frame: -1 Memory @ Loss: 4.456                                                                                                
Scale: 2 Frame: 1 Memory @ Before Net: 4.456                                                                                           
Scale: 2 Frame: 1 Memory @ After Net: 4.559                                                                                            
Scale: 2 Frame: 1 Memory @ Loss: 4.559                                                                                                 
----------------------------------------------------------------------                                                                 
Scale: 3 Frame: -1 Memory @ Before Net: 4.559                                                                                          
Scale: 3 Frame: -1 Memory @ After Net: 4.584
Scale: 3 Frame: -1 Memory @ Loss: 4.584
Scale: 3 Frame: 1 Memory @ Before Net: 4.584
Scale: 3 Frame: 1 Memory @ After Net: 4.61
Scale: 3 Frame: 1 Memory @ Loss: 4.61
----------------------------------------------------------------------
Success
----------------------------------------------------------------------
Using padding mode: replicate
----------------------------------------------------------------------
Process Started Memory: 0.0
----------------------------------------------------------------------
Inputs Created Memory: 0.251
----------------------------------------------------------------------
Scale: 0 Frame: -1 Memory @ Before Net: 0.251
Scale: 0 Frame: -1 Memory @ After Net: 3.716
Scale: 0 Frame: -1 Memory @ Loss: 3.716
Scale: 0 Frame: 1 Memory @ Before Net: 3.716
Scale: 0 Frame: 1 Memory @ After Net: 7.179
Scale: 0 Frame: 1 Memory @ Loss: 7.179
----------------------------------------------------------------------
Scale: 1 Frame: -1 Memory @ Before Net: 7.179
Scale: 1 Frame: -1 Memory @ After Net: 8.064
Scale: 1 Frame: -1 Memory @ Loss: 8.064
Scale: 1 Frame: 1 Memory @ Before Net: 8.064
Scale: 1 Frame: 1 Memory @ After Net: 8.949
Scale: 1 Frame: 1 Memory @ Loss: 8.949
----------------------------------------------------------------------
Scale: 2 Frame: -1 Memory @ Before Net: 8.949
Scale: 2 Frame: -1 Memory @ After Net: 9.18
Scale: 2 Frame: -1 Memory @ Loss: 9.18
Scale: 2 Frame: 1 Memory @ Before Net: 9.18
Scale: 2 Frame: 1 Memory @ After Net: 9.411
Scale: 2 Frame: 1 Memory @ Loss: 9.411
----------------------------------------------------------------------
Scale: 3 Frame: -1 Memory @ Before Net: 9.411
Scale: 3 Frame: -1 Memory @ After Net: 9.475
Scale: 3 Frame: -1 Memory @ Loss: 9.475
Scale: 3 Frame: 1 Memory @ Before Net: 9.475
Scale: 3 Frame: 1 Memory @ After Net: 9.538
Scale: 3 Frame: 1 Memory @ Loss: 9.538
----------------------------------------------------------------------
Success
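A back-of-the-envelope check of the scale-0 numbers, under two stated assumptions: activations are float32, and each replicate-padded Conv3d keeps exactly one extra padded copy of its input alive until backward. Layer input shapes are read off the torchinfo summary and the scale-0 input; depth is padded by 1 in the first four convolutions and by 0 in the last one, per padding=(0, padw, padw). The script divides by 1024**3, so the logged values are in GiB.

```python
def padded_nbytes(shape, pad, bytes_per_elem=4):
    """Size in bytes of an (N, C, D, H, W) float32 tensor after padding D, H, W on both sides."""
    n, c, d, h, w = shape
    pd, ph, pw = pad
    return n * c * (d + 2 * pd) * (h + 2 * ph) * (w + 2 * pw) * bytes_per_elem


# Inputs to the five Conv3d layers at scale 0, with the (D, H, W) padding each uses.
conv_inputs = [
    ((16, 3, 3, 256, 640), (1, 1, 1)),
    ((16, 64, 3, 128, 320), (1, 1, 1)),
    ((16, 128, 3, 64, 160), (1, 1, 1)),
    ((16, 256, 3, 32, 80), (1, 1, 1)),
    ((16, 512, 3, 32, 80), (0, 1, 1)),  # last layer: padding=(0, padw, padw)
]

estimate_gib = sum(padded_nbytes(s, p) for s, p in conv_inputs) / 1024**3
# Growth during the first forward pass (scale 0, frame -1) in each run, from the log above.
observed_gap_gib = (3.716 - 0.251) - (1.892 - 0.251)

print(f"estimated extra memory: {estimate_gib:.3f} GiB")      # 1.822 GiB
print(f"observed difference:    {observed_gap_gib:.3f} GiB")  # 1.824 GiB
```

Under those assumptions, the padded copies alone account for nearly all of the extra memory per forward pass at scale 0.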

Hi, has anyone been able to reproduce this?