GPU memory spike with torchvision.transforms

Hi guys. I’m currently a first-year Ph.D. student in robotics. In my current project I need to train an image-based diffusion policy. During training, I see very high VRAM usage. Here are some training details:
Param size: 90M
Precision: float32
Batch size: 512
Image shape: (3, 224, 224) already on ‘cuda’ device

I tried to debug the memory usage by stepping through the training loop in a debugger. The weird thing I found is that, after passing the batch through a few torchvision.transforms, the memory usage increased by ~5 GB. These are my transforms:

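import torch
from torchvision import transforms

# input_size and margin are defined elsewhere in the training script
transform_train = transforms.Compose([
    transforms.Resize([240, 320], antialias=True),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
    transforms.GaussianBlur(kernel_size=5, sigma=(0.01, 2.0)),
    transforms.CenterCrop((input_size[0], input_size[1] - 2 * margin)),
])
# Batch of 512 float32 images, moved to the GPU before the transforms run
a = torch.randn([512, 3, 224, 224]).to('cuda')
b = transform_train(a)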

I’m assuming the transforms do not require gradients, so there shouldn’t be such high memory usage. What am I missing here?

Tensor a, of shape (512, 3, 224, 224), holds 77,070,336 values in 32 bits, meaning moving this tensor to the GPU already takes around 308 MB of VRAM. After Resize([240, 320]) the batch holds 117,964,800 values, which is around 472 MB.
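The size arithmetic can be checked in plain Python. This is a minimal sketch that assumes float32 storage (4 bytes per value) and ignores allocator overhead; the helper name tensor_bytes is made up for illustration.

```python
import math

def tensor_bytes(shape, bytes_per_value=4):
    """Return (number of values, size in bytes) for a dense tensor of this shape."""
    count = math.prod(shape)
    return count, count * bytes_per_value

# Batch as created in the question: 512 images of 3 x 224 x 224
values_in, bytes_in = tensor_bytes((512, 3, 224, 224))
# The same batch after Resize([240, 320])
values_resized, bytes_resized = tensor_bytes((512, 3, 240, 320))

print(f"input:   {values_in:,} values, {bytes_in / 1e6:.0f} MB")
# input:   77,070,336 values, 308 MB
print(f"resized: {values_resized:,} values, {bytes_resized / 1e6:.0f} MB")
# resized: 117,964,800 values, 472 MB
```

So every full-size intermediate copy made after the resize costs roughly another 472 MB at this batch size.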

As far as I know, each transform step may need additional VRAM for intermediate results and caching, and these allocations can add up to a few GB.
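One way to see how much each step contributes is to reset PyTorch's peak-memory counter before a transform and read it afterwards. The sketch below is only an illustration: peak_extra_bytes is a hypothetical helper, the batch is cut to 32 images to keep the demo quick, each step is measured on the same input rather than chained, and CenterCrop is left out because input_size and margin are not given in the question.

```python
import torch

def peak_extra_bytes(step, x):
    """Peak CUDA bytes allocated above the starting level while running step(x).

    The figure includes the output tensor itself. Returns None when x is not
    a CUDA tensor, because these counters only track CUDA memory.
    """
    if not x.is_cuda:
        return None
    torch.cuda.synchronize(x.device)
    torch.cuda.reset_peak_memory_stats(x.device)
    start = torch.cuda.memory_allocated(x.device)
    with torch.no_grad():  # the transforms need no autograd graph
        out = step(x)
    torch.cuda.synchronize(x.device)
    peak = torch.cuda.max_memory_allocated(x.device)
    del out
    return peak - start


if torch.cuda.is_available():
    from torchvision import transforms  # only needed for this demo

    # Assumption: 32 images instead of the question's 512
    x = torch.rand(32, 3, 224, 224, device="cuda")
    steps = [
        ("Resize", transforms.Resize([240, 320], antialias=True)),
        ("ColorJitter", transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)),
        ("GaussianBlur", transforms.GaussianBlur(kernel_size=5, sigma=(0.01, 2.0))),
    ]
    for name, step in steps:
        extra = peak_extra_bytes(step, x)
        print(f"{name}: peak extra {extra / 2**20:.1f} MiB")
```

Comparing the printed peaks with the size of the input batch gives a rough idea of how many temporary copies a step creates.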

You can print a memory report at the end of the code that you provided, for example with print(torch.cuda.memory_summary()), and pay attention to the Allocated memory and GPU reserved memory rows. In my case, the GPU reserved memory is several times higher than the allocated memory, meaning PyTorch's caching allocator has reserved a lot of additional VRAM for potential future use.
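The snippet from this reply did not survive, so the following is only a sketch of one way to get those two numbers, assuming the standard torch.cuda counters. cuda_memory_report is a hypothetical helper name, and the 256 MiB scratch tensor is an arbitrary choice to make the gap visible.

```python
import torch

def cuda_memory_report(device=None):
    """Return allocated and reserved CUDA bytes, or None when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    return {
        "allocated_bytes": torch.cuda.memory_allocated(device),
        "reserved_bytes": torch.cuda.memory_reserved(device),
    }


if torch.cuda.is_available():
    # Assumption: a 256 MiB float32 scratch tensor, allocated and then freed
    scratch = torch.empty(64, 1024, 1024, device="cuda")
    del scratch  # the freed block goes back to PyTorch's cache, not to the driver

    report = cuda_memory_report()
    print(f"allocated: {report['allocated_bytes'] / 2**20:.1f} MiB")
    print(f"reserved:  {report['reserved_bytes'] / 2**20:.1f} MiB")
    # Full table, including the "Allocated memory" and "GPU reserved memory" rows
    print(torch.cuda.memory_summary())
```

Tools like nvidia-smi report something closer to the reserved figure, which is why usage can look much higher than the tensors that are actually alive. Calling torch.cuda.empty_cache() hands unused cached blocks back to the driver if the numbers need to be compared directly.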