Hi, guys. I'm a first-year Ph.D. student in robotics. For my current project I need to train an image-based diffusion policy, and during training I'm seeing very high VRAM usage. Here are some training details:
Param size: ~90M
Precision: float32
Batch size: 512
Image shape: (3, 224, 224), already on the 'cuda' device
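For a rough sense of scale, this is my back-of-the-envelope arithmetic for those numbers (just an estimate, not a measurement):

# Rough float32 sizes (4 bytes per element); my own estimate, not measured.
param_bytes = 90e6 * 4                 # ~0.36 GB for the model weights
batch_bytes = 512 * 3 * 224 * 224 * 4  # ~0.31 GB for one image batch
print(f"params: {param_bytes / 1e9:.2f} GB, batch: {batch_bytes / 1e9:.2f} GB")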
I tried to debug the memory usage by stepping through the training loop in a debugger. The weird thing is that after the batch passes through a few torchvision.transforms, memory usage increases by ~5 GB. Here are my transforms:
import torch
from torchvision import transforms

# input_size and margin are defined elsewhere in my script
transform_train = transforms.Compose(
    [
        transforms.Resize([240, 320], antialias=True),
        transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3
        ),
        transforms.GaussianBlur(kernel_size=5, sigma=(0.01, 2.0)),
        transforms.CenterCrop((input_size[0], input_size[1] - 2 * margin)),
    ]
)

a = torch.randn([512, 3, 224, 224]).to('cuda')
b = transform_train(a)
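This is roughly how I'm measuring the jump around that call (a minimal sketch; the exact checkpoints in my real training loop differ):

# Measure allocated and peak CUDA memory around the transform call
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
before = torch.cuda.memory_allocated()

b = transform_train(a)

torch.cuda.synchronize()
after = torch.cuda.memory_allocated()
peak = torch.cuda.max_memory_allocated()
print(f"delta: {(after - before) / 1e9:.2f} GB, peak: {peak / 1e9:.2f} GB")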
I'm assuming the transforms don't require gradients, so there shouldn't be such high memory usage. What am I missing here?
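To sanity-check that assumption, I ran something like the following (a quick sketch), which confirms the output doesn't track gradients and wraps the call in torch.no_grad() just in case:

# Neither input nor output carries requires_grad, so no autograd graph is built here
with torch.no_grad():
    b = transform_train(a)
print(a.requires_grad, b.requires_grad)  # False False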