Hello PyTorch community!
I am trying to train a 3D-conv-based model (summary printed below using torchinfo, and a rough sketch of the layers right after it). My input shape is (16, 3, 3, 640, 256).
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Sequential: 1-1 [16, 1, 1, 32, 80] --
Conv3d: 2-1 [16, 64, 3, 128, 320] 5,248
LeakyReLU: 2-2 [16, 64, 3, 128, 320] --
Conv3d: 2-3 [16, 128, 3, 64, 160] 221,184
BatchNorm3d: 2-4 [16, 128, 3, 64, 160] 256
LeakyReLU: 2-5 [16, 128, 3, 64, 160] --
Conv3d: 2-6 [16, 256, 3, 32, 80] 884,736
BatchNorm3d: 2-7 [16, 256, 3, 32, 80] 512
LeakyReLU: 2-8 [16, 256, 3, 32, 80] --
Conv3d: 2-9 [16, 512, 3, 32, 80] 3,538,944
BatchNorm3d: 2-10 [16, 512, 3, 32, 80] 1,024
LeakyReLU: 2-11 [16, 512, 3, 32, 80] --
Conv3d: 2-12 [16, 1, 1, 32, 80] 13,825
==========================================================================================
Total params: 4,665,729
Trainable params: 4,665,729
Non-trainable params: 0
Total mult-adds (G): 663.18
==========================================================================================
Input size (MB): 94.37
Forward/backward pass size (MB): 3523.54
Params size (MB): 18.66
Estimated Total Size (MB): 3636.58
==========================================================================================
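For reference, roughly what the Sequential block looks like. This is a sketch reconstructed from the summary above: the kernel sizes, strides, and bias settings are inferred from the output shapes and param counts, so they may differ slightly from my actual code. padding_mode is the argument I switch between 'zeros' and 'replicate'.

```python
import torch.nn as nn

def make_block(padding_mode='zeros'):
    # Sketch of the Sequential block; param counts per layer match the summary.
    return nn.Sequential(
        nn.Conv3d(3, 64, 3, stride=(1, 2, 2), padding=1,
                  padding_mode=padding_mode),              # 5,248 params
        nn.LeakyReLU(),
        nn.Conv3d(64, 128, 3, stride=(1, 2, 2), padding=1,
                  padding_mode=padding_mode, bias=False),  # 221,184 params
        nn.BatchNorm3d(128),
        nn.LeakyReLU(),
        nn.Conv3d(128, 256, 3, stride=(1, 2, 2), padding=1,
                  padding_mode=padding_mode, bias=False),  # 884,736 params
        nn.BatchNorm3d(256),
        nn.LeakyReLU(),
        nn.Conv3d(256, 512, 3, stride=1, padding=1,
                  padding_mode=padding_mode, bias=False),  # 3,538,944 params
        nn.BatchNorm3d(512),
        nn.LeakyReLU(),
        nn.Conv3d(512, 1, 3, stride=1, padding=(0, 1, 1),
                  padding_mode=padding_mode),              # 13,825 params
    )
```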
When I use zero padding, my model does not run into any memory issues, but when I change padding_mode to 'replicate' I get a CUDA out-of-memory error, which seems to be raised while padding the input:
return torch._C._nn.replication_pad3d(input, pad)
RuntimeError: CUDA out of memory. Tried to allocate 656.00 MiB (GPU 0; 31.75 GiB total capacity; 29.27 GiB already allocated; 632.00 MiB free; 29.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
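From the traceback, it looks like non-zero padding modes go through an explicit padding op (replication_pad3d) that materializes a padded copy of the activation before the convolution, whereas zero padding is handled inside the convolution kernel itself. Here is a minimal single-layer comparison I used to sanity-check this (a sketch; the layer sizes here are just for illustration):

```python
import torch
import torch.nn as nn

x = torch.randn(16, 3, 3, 640, 256, device='cuda')

for mode in ('zeros', 'replicate'):
    conv = nn.Conv3d(3, 64, 3, stride=(1, 2, 2), padding=1,
                     padding_mode=mode).cuda()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    out = conv(x)
    out.sum().backward()
    # peak memory used by live tensors during forward + backward
    print(mode, torch.cuda.max_memory_allocated() / 2**30, 'GB peak')
    del conv, out
```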
I also checked the memory allocated to tensors at several points during the forward pass (I make multiple forward passes in a single iteration through the network), and with replicate padding the allocated memory does grow significantly faster.
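The numbers below are produced roughly like this (a sketch; log_mem is just a helper name):

```python
import torch

def log_mem(tag=''):
    # memory_allocated() counts only memory occupied by live tensors,
    # not the larger pool reserved by the caching allocator
    print('mem allocated', tag, torch.cuda.memory_allocated() / 2**30)
```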
All values in GB.
Zero Padding
mem allocated at start of iter 1 0.1781
mem allocated 12.34
mem allocated 15.80
mem allocated 19.19
mem allocated 20.06
mem allocated 20.90
mem allocated 21.12
mem allocated 21.33
mem allocated 21.39
SUCCESSFULLY COMPLETED iter 1
Replicate Padding
mem allocated at start of iter 1 0.1781
mem allocated 12.34
mem allocated 19.27
mem allocated 26.13
mem allocated 27.90
mem allocated 29.66
mem allocated 30.13
CRASHED AT iter 1
Is this expected behaviour of replicate padding in 3D convolutions, or is there anything I can do to reduce the memory usage?