GPU memory spike with torchvision.transforms

Hi, everyone. I'm a first-year Ph.D. student in robotics. In my current project I need to train an image-based diffusion policy, and during training I'm seeing very high VRAM usage. Here are some training details:
Param size: ~90M
Precision: float32
Batch size: 512
Image shape: (3, 224, 224), already on the 'cuda' device

I tried to debug the memory usage by stepping through the training loop in a debugger. The weird thing I found is that, after passing a batch through a few torchvision.transforms, memory usage increased by ~5 GB. These are my transforms:

import torch
from torchvision import transforms

# input_size and margin are defined earlier in my script
transform_train = transforms.Compose(
    [
        transforms.Resize([240, 320], antialias=True),
        transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3
        ),
        transforms.GaussianBlur(kernel_size=5, sigma=(0.01, 2.0)),
        transforms.CenterCrop((input_size[0], input_size[1] - 2 * margin)),
    ]
)

a = torch.randn([512, 3, 224, 224]).to('cuda')
b = transform_train(a)

I'm assuming the transforms don't require gradients, so there shouldn't be such high memory usage. What am I missing here?

Tensor a, of shape (512, 3, 224, 224), holds 77,070,336 float32 values, so just moving it to the GPU already takes ~308 MB of VRAM. After the Resize to (240, 320), the intermediate tensor holds 117,964,800 values, i.e. ~472 MB.
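You can get these numbers directly from the tensors themselves; a minimal sketch (the shapes follow your snippet):

import torch
from torchvision import transforms

a = torch.randn(512, 3, 224, 224, device='cuda')
print(a.nelement() * a.element_size())   # 308,281,344 bytes, ~308 MB for float32

resized = transforms.Resize([240, 320], antialias=True)(a)
print(resized.nelement() * resized.element_size())   # 471,859,200 bytes, ~472 MB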

As far as I know, each transform in the pipeline may allocate additional VRAM for intermediate results, and those intermediates add up, which makes a few GB of usage plausible. You can see the per-step cost with a sketch like the one below.
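One rough way to see where the memory goes is to apply the transforms one at a time and print the allocator stats after each step. The input_size and margin values here are placeholders I made up, since they are not in your snippet:

import torch
from torchvision import transforms

# hypothetical values for input_size / margin, just for illustration
input_size, margin = (240, 320), 10

steps = [
    transforms.Resize([240, 320], antialias=True),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
    transforms.GaussianBlur(kernel_size=5, sigma=(0.01, 2.0)),
    transforms.CenterCrop((input_size[0], input_size[1] - 2 * margin)),
]

x = torch.randn(512, 3, 224, 224, device='cuda')
print(f"input: {torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated")
for t in steps:
    x = t(x)
    print(f"after {t.__class__.__name__}: "
          f"{torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated, "
          f"{torch.cuda.memory_reserved() / 2**20:.0f} MiB reserved")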

You can add the following at the end of the code you provided:

print(torch.cuda.memory_summary(device='cuda'))

and pay attention to the allocated memory and GPU reserved memory. In my case, the reserved memory is several times higher than the allocated memory, meaning CUDA's caching allocator has reserved a lot of additional VRAM for potential future use.
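If you only want those two headline numbers rather than the full summary, a minimal sketch like this prints them directly; empty_cache() returns cached-but-unused blocks to the driver, which usually shrinks the reserved figure without changing how much the transforms themselves actually need:

import torch

print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")
print(f"reserved:  {torch.cuda.memory_reserved() / 2**20:.0f} MiB")

# Release unused cached blocks; allocated stays the same,
# but reserved should drop closer to it.
torch.cuda.empty_cache()
print(f"reserved after empty_cache: {torch.cuda.memory_reserved() / 2**20:.0f} MiB")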