Running out of GPU memory when doing transfer learning with resnet34

Hi, all.

I am working on a semantic segmentation task with my own model and keep running out of GPU memory, and I have no idea why this is happening.
(My GPU is a GTX 1070 with 8 GB of video memory.)

My model uses ResNet34 from torchvision as the encoder, with a U-Net-style decoder on top of it.
To build the skip connections, I extracted the intermediate outputs of ResNet34's layers with forward hooks, following this post.

Here is my code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models


class DEC_Block(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels):
        super().__init__()
        # upsample 2x and project to the skip's channel count
        self.tconv = nn.ConvTranspose2d(in_channels, skip_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(skip_channels * 2, out_channels, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()

    def forward(self, x, skip):
        x = self.tconv(x)
        # crop the extra row/column left over when an encoder feature map had an odd size
        x = x[:, :, :skip.shape[-2], :skip.shape[-1]]
        x = torch.cat((x, skip), -3)  # cat over the channel axis
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        return x


class Decoder(nn.Module):
    def __init__(self, skips, out_channels):
        super().__init__()
        self.skips = skips  # same list object as seg_resnet34.skips, filled by the hooks
        self.dec1 = DEC_Block(512, 256, 256)
        self.dec2 = DEC_Block(256, 128, 128)
        self.dec3 = DEC_Block(128, 64, 64)
        self.dec4 = DEC_Block(64, 64, 64)
        self.conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.dec1(x, self.skips[3])  # in: child 7, skip: child 6
        x = self.dec2(x, self.skips[2])  # skip: child 5
        x = self.dec3(x, self.skips[1])  # skip: child 4
        x = self.dec4(x, self.skips[0])  # skip: child 2
        x = self.conv(x)
        return x


class seg_resnet34(nn.Module):
    def __init__(self, out_channels):
        super().__init__()
        self.Encoder = nn.Sequential(*list(models.resnet34(pretrained=True).children())[:-2])
        self.skips = []
        for enc_params in self.Encoder.parameters():
            enc_params.requires_grad = False
        # hook registration (assumed; indices taken from the Decoder comments)
        for idx in (2, 4, 5, 6):
            self.Encoder[idx].register_forward_hook(self.get_skip)
        self.Decoder = Decoder(self.skips, out_channels)

    def get_skip(self, module, input, output):
        self.skips.append(output)  # store each intermediate output for the decoder

    def forward(self, x):
        x = self.Encoder(x)
        x = self.Decoder(x)
        return x


class mIoULoss(nn.Module):
    """Soft mean IoU loss."""

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-3)  # softmax over the channel axis
        self.smooth = 1e-5

    def forward(self, output, label, weight=None):
        if weight is None:
            weight = 1
        else:
            weight = self.softmax(weight)
        output = self.softmax(output)
        inter = torch.sum(output * label, dim=(-1, -2))
        union = torch.sum(output + label, dim=(-1, -2)) - inter
        iou = (inter + self.smooth) / (union + self.smooth) * weight
        return -torch.mean(iou)


class BCEIoULoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce = F.binary_cross_entropy_with_logits
        self.iou = mIoULoss()

    def forward(self, output, label, weight=None):
        bce = self.bce(output, label, weight=weight)
        iou = self.iou(output, label, weight=weight)
        return 0.5 * bce + iou


criterion = BCEIoULoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)  # model = seg_resnet34(...), defined elsewhere
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# def train_model(SOME_PARAMS): ...  (omitted)
```
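The crop in DEC_Block.forward is needed because ResNet34's stride-2 layers round odd sizes up, so doubling a feature map can overshoot its skip by one row or column. Below is a small pure-Python sketch of that shape arithmetic. It assumes the 152 × 242 input mentioned later in the thread and torchvision's ResNet34 settings (7×7 stride-2 conv1, 3×3 stride-2 max-pool, stride-2 first block in layer2 to layer4); the helper names are illustrative only.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv/pool layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet34_feature_sizes(h, w):
    """(H, W) of the encoder outputs, keyed by child index of the Sequential."""
    sizes = {}
    h, w = conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3)  # conv1 -> bn1 -> relu (child 2)
    sizes[2] = (h, w)
    h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)  # maxpool (child 3)
    sizes[4] = (h, w)                                  # layer1 keeps the size
    for idx in (5, 6, 7):                              # layer2..layer4 halve it
        h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
        sizes[idx] = (h, w)
    return sizes


def decoder_crops(sizes):
    """Pair each 2x transposed-conv output with the skip it is concatenated to."""
    h, w = sizes[7]
    pairs = []
    for skip_idx in (6, 5, 4, 2):
        h, w = 2 * h, 2 * w                # ConvTranspose2d(kernel_size=2, stride=2)
        pairs.append(((h, w), sizes[skip_idx]))
        h, w = sizes[skip_idx]             # after cropping, x matches the skip
    return pairs


sizes = resnet34_feature_sizes(152, 242)
for up, skip in decoder_crops(sizes):
    print(f"upsampled {up} -> skip {skip}")
```

Three of the four stages overshoot by one pixel in at least one dimension. The last pair also shows that the decoder output is 76 × 121, half the input resolution, because the final skip comes from the stride-2 conv1.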


The train_model function is almost the same as the one in the PyTorch tutorial.

Can anyone help me solve this "GPU running out of memory" issue? Thanks!

How large is your input, and what batch size are you using?
Could you try lowering the batch size or resizing the input to see at which point the GPU memory becomes sufficient?

Hi, ptrblck! Thank you for the reply.

First, to answer your question: the original images were resized to 152 × 242 pixels and I used batch_size = 1, but I still ran out of memory before the first epoch finished.
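With batch_size = 1 and such a small input, a single forward pass should need little memory, which points to something accumulating across iterations rather than to the input size. The sketch below is a rough estimate of how quickly a list that keeps every hooked output would fill the card. It assumes float32 activations, batch size 1, hooks on children 2, 4, 5 and 6 (the indices in the Decoder comments), and the feature-map sizes ResNet34 produces for a 152 × 242 input; it ignores weights, optimizer state and the decoder's own activations.

```python
# (channels, H, W) of the tensors a hook would append on each forward pass,
# assuming a 152x242 input and hooks on ResNet34 children 2, 4, 5, 6.
SKIP_SHAPES = [(64, 76, 121), (64, 38, 61), (128, 19, 31), (256, 10, 16)]
BYTES_PER_FLOAT32 = 4
GPU_BYTES = 8 * 1024**3  # nominal 8 GiB card


def bytes_per_forward(shapes, batch_size=1):
    """Bytes retained per forward pass if every hooked output is kept."""
    total = 0
    for c, h, w in shapes:
        total += batch_size * c * h * w * BYTES_PER_FLOAT32
    return total


per_step = bytes_per_forward(SKIP_SHAPES)
print(f"{per_step / 1024**2:.2f} MiB retained per forward pass")
print(f"about {GPU_BYTES // per_step} forward passes to fill 8 GiB")
```

At roughly 3.25 MiB per step, the list alone would exhaust the card after about 2,500 forward passes, which would be consistent with running out of memory partway through the first epoch on a training set of a few thousand images.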

Here is some good news (well, good for me). I solved the problem by assigning the ResNet34 layers to separate attributes and calling them explicitly in forward() instead of using hooks. I believe the cause was that the list storing the intermediate outputs, self.skips in class seg_resnet34, was never cleared, so every forward pass appended another set of feature maps and kept them alive on the GPU. (Also, if self.skips is never refreshed, the rest of my code can't work anyway: the decoder reads self.skips[0] to self.skips[3], which always point to the outputs of the very first forward pass.)

Again, thank you for your help.

