CUDA out of memory when changing first conv layer of resnet


I am trying to change the first layer of the resnet-50-fpn backbone inside the Mask R-CNN provided by the torchvision library. I changed it to 4 input channels for RGB-depth data, keeping the pretrained weights for the first 3 channels, using the following code:

self.backbone = resnet_fpn_backbone(arch, pretrained)
# change first layer to 4 channel for early fusion with 1 channel depth, load pretrained weights on RGB channels
conv1_weight_old = self.backbone.body.conv1.weight
conv1_weight = torch.zeros((64, 4, 7, 7))
conv1_weight[:, 0:3, :, :] = conv1_weight_old
avg_weight = conv1_weight_old.mean(dim=1, keepdim=False)
conv1_weight[:, 3, :, :] = avg_weight
self.backbone.body.conv1.weight = torch.nn.Parameter(conv1_weight)
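One detail worth noting about the snippet above: the slice assignments into `conv1_weight` go through autograd, so the new tensor can keep a reference to the old Parameter's graph. A minimal variant that builds the weight with gradient tracking disabled and matches the old weight's device/dtype (the stand-in `conv1` module here is an assumption, mirroring the shapes in the post):

```python
import torch
import torch.nn as nn

# Stand-in for self.backbone.body.conv1 (same shape as in the post)
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

with torch.no_grad():
    old = conv1.weight.detach()
    # Build the 4-channel weight on the same device/dtype as the old one
    new_weight = torch.zeros(64, 4, 7, 7, dtype=old.dtype, device=old.device)
    new_weight[:, :3] = old             # keep the pretrained RGB filters
    new_weight[:, 3] = old.mean(dim=1)  # init the depth channel with the RGB mean

# Fresh leaf tensor; no graph from the old Parameter is retained
conv1.weight = nn.Parameter(new_weight)
```

This is only a sketch of the same weight surgery, not a confirmed fix for the OOM.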

The problem is very weird: it raises a 'CUDA out of memory' error after training 1 or 2 epochs, right when it starts a new one. When I pause the program in debug mode on the line that runs the forward pass through the backbone, I can see that GPU memory is full. As a test, I trained 10 epochs using 10 images: sometimes it works fine and completes all epochs, but in most cases it gives the memory error.
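In case it helps, GPU usage can also be checked programmatically instead of through the debugger. A small helper (the function name is my own) around torch.cuda's allocation counters:

```python
import torch

def gpu_mem_summary() -> str:
    """Return current and peak allocated CUDA memory as a short string."""
    if not torch.cuda.is_available():
        return "CUDA not available"
    cur = torch.cuda.memory_allocated() / 1024**2
    peak = torch.cuda.max_memory_allocated() / 1024**2
    return f"allocated: {cur:.1f} MiB, peak: {peak:.1f} MiB"

# e.g. print(gpu_mem_summary()) right before the backbone forward pass
```

Printing this once per epoch would show whether allocated memory is actually growing over time or spikes only at the failing step.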

I also tried commenting out the two lines that compute the average, but it makes no difference.
The first conv layer in resnet looked like this before I changed its shape:

nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
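An alternative to swapping only the `.weight` Parameter is to replace the whole module, so that the layer's `in_channels` attribute stays consistent with the new weight shape. A sketch under the same assumptions as the post (the `old_conv` stand-in plays the role of `self.backbone.body.conv1`):

```python
import torch
import torch.nn as nn

# Stand-in for the original self.backbone.body.conv1
old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# New 4-channel layer with otherwise identical hyperparameters
new_conv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    new_conv.weight[:, :3] = old_conv.weight             # copy pretrained RGB filters
    new_conv.weight[:, 3] = old_conv.weight.mean(dim=1)  # depth channel from RGB mean

# then, in the model: self.backbone.body.conv1 = new_conv
```

Whether this changes the memory behaviour is untested; it just keeps the module metadata and the weight in sync.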

I think you need to decrease your batch size. When you add one additional channel, the network grows compared to the original resnet, and your GPU memory is no longer enough for this batch size.

Actually I don’t think it’s a memory size problem. I found that somehow I can now run it in ‘run’ mode without hitting this bug anymore, but it still fails in ‘debug’ mode (the two modes in the PyCharm IDE). It’s weird, but that’s what happens. Thanks anyway!

I have the same problem with PyCharm debugging too. I try to debug my code without parallelization and with a smaller batch size and less data. I think the debugger itself consumes too much memory. I also tried VS Code; it is lighter and faster for problems like this.

Oh, I can try that! I tried it before but found it could not step into the functions related to the neural networks.