How to solve OutOfMemoryError: CUDA out of memory

Hi there,

I’m building a model on Kaggle with a 2D ResNet50 encoder and a 3D U-Net decoder for medical image segmentation, but I keep hitting `OutOfMemoryError: CUDA out of memory`. I’ve tried multiple fixes I found online but nothing has worked for me. Any ideas on how to solve the problem?

This is my net:

# Allocator tuning via environment variables; these must be set before the
# frameworks initialize CUDA to have any effect.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH']='true'  # TensorFlow-only setting; has no effect on PyTorch — TODO confirm TF is actually used here
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"  # cap PyTorch's allocator split size to reduce fragmentation
def free_gpu_cache():
    """Report GPU memory usage, then release PyTorch's cached CUDA memory.

    Fix: the original version called numba's ``cuda.select_device(0)`` /
    ``cuda.close()``.  ``cuda.close()`` tears down the primary CUDA context,
    which PyTorch still holds, so subsequent PyTorch CUDA calls fail.  The
    safe equivalent is a ``gc.collect()`` (so unreachable tensors are really
    freed) followed by ``torch.cuda.empty_cache()`` (so their cached blocks
    are returned to the driver).
    """
    print("Initial GPU Usage")
    gpu_usage()

    gc.collect()              # free unreachable tensors so their blocks become releasable
    torch.cuda.empty_cache()  # hand cached, unused blocks back to the CUDA driver

    print("GPU Usage after emptying the cache")
    gpu_usage()
class ResNetEncoder(nn.Module):
    """2D ResNet-50 feature extractor applied per slice.

    The RGB stem conv is replaced with a single-input-channel version (the
    volumes are single-channel), and layer1..layer4 run under activation
    checkpointing: their activations are discarded during forward and
    recomputed during backward, trading compute for a large cut in peak GPU
    memory.
    """

    def __init__(self, pretrained=True):
        """Build the encoder.

        Args:
            pretrained: load ImageNet weights for the ResNet-50 backbone.
                NOTE(review): ``pretrained=`` is deprecated in recent
                torchvision in favor of ``weights=``; kept here for
                backward compatibility with the caller.
        """
        super(ResNetEncoder, self).__init__()
        self.resnet50 = resnet50(pretrained=pretrained)
        # Single-channel stem; freshly initialized, so it is NOT pretrained.
        self.resnet50.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        # Stem (conv -> BN -> ReLU -> maxpool) is cheap; not checkpointed.
        x = self.resnet50.conv1(x)
        x = self.resnet50.bn1(x)
        x = self.resnet50.relu(x)
        x = self.resnet50.maxpool(x)
        # use_reentrant=False is the recommended checkpointing mode: it works
        # even when inputs do not require grad and composes with autocast.
        enc1 = checkpoint.checkpoint(self.resnet50.layer1, x, use_reentrant=False)
        enc2 = checkpoint.checkpoint(self.resnet50.layer2, enc1, use_reentrant=False)
        enc3 = checkpoint.checkpoint(self.resnet50.layer3, enc2, use_reentrant=False)
        enc4 = checkpoint.checkpoint(self.resnet50.layer4, enc3, use_reentrant=False)

        # Deepest first, matching the decoder's argument order.
        return enc4, enc3, enc2, enc1


class UNet3DDecoder(nn.Module):
    def __init__(self):
        super(UNet3DDecoder, self).__init__()
        self.conv_reduce4 = nn.Conv3d(3072, 2048, kernel_size=1)  # Add this layer to reduce channels
        self.decoder4 = nn.Sequential(
            nn.ConvTranspose3d(2048, 1024, kernel_size=2, stride=2),
            nn.Conv3d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm3d(1024),
            nn.ReLU(inplace=True)
        )
        
        self.conv_reduce3 = nn.Conv3d(1536, 1024, kernel_size=1)  # Adjust channels here as well
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose3d(1024, 512, kernel_size=2, stride=2),
            nn.Conv3d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(inplace=True)
        )
        
        self.conv_reduce2 = nn.Conv3d(768, 512, kernel_size=1)  # Adjust channels here as well
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2),
            nn.Conv3d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(inplace=True)
        )
        
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2),
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True)
        )
        
        self.vessel_out = nn.Conv3d(128, 1, kernel_size=1)
        self.kidney_out = nn.Conv3d(128, 1, kernel_size=1)

    def forward(self, enc4, enc3, enc2, enc1):
        #print("Shape of enc4 before interpolation:", enc4.shape)
        #print("Shape of enc3:", enc3.shape)
        interpolated_enc4 = F.interpolate(enc4, size=enc3.shape[2:], mode='trilinear', align_corners=True)
        concatenated = torch.cat([interpolated_enc4, enc3], dim=1)
        concatenated = self.conv_reduce4(concatenated)  # Reduce channels here
        dec4 = self.decoder4(concatenated)
        
        #print("Shape of dec4:", dec4.shape)
        #print("Shape of enc2:", enc2.shape)
        interpolated_dec4 = F.interpolate(dec4, size=enc2.shape[2:], mode='trilinear', align_corners=True)
        concatenated = torch.cat([interpolated_dec4, enc2], dim=1)
        concatenated = self.conv_reduce3(concatenated)  # Reduce channels here
        dec3 = self.decoder3(concatenated)
        
        #print("Shape of dec3:", dec3.shape)
        #print("Shape of enc1:", enc1.shape)
        interpolated_dec3 = F.interpolate(dec3, size=enc1.shape[2:], mode='trilinear', align_corners=True)
        concatenated = torch.cat([interpolated_dec3, enc1], dim=1)
        concatenated = self.conv_reduce2(concatenated)  # Reduce channels here
        dec2 = self.decoder2(concatenated)
        
        #print("Shape of dec2:", dec2.shape)
        interpolated_dec2 = F.interpolate(dec2, scale_factor=2, mode='trilinear', align_corners=True)
        dec1 = self.decoder1(interpolated_dec2)
        #print("dec1 = interpolated dec2 =", dec1.shape)
        
        vessel = self.vessel_out(dec1)
        kidney = self.kidney_out(dec1)
        
        return vessel, kidney

        
    ''''def forward(self, enc4, enc3, enc2, enc1):
        dec4 = self.decoder4(torch.cat([F.interpolate(enc4, size=enc3.shape[2:], mode='trilinear', align_corners=True), enc3], dim=1))
        dec3 = self.decoder3(torch.cat([F.interpolate(dec4, size=enc2.shape[2:], mode='trilinear', align_corners=True), enc2], dim=1))
        dec2 = self.decoder2(torch.cat([F.interpolate(dec3, size=enc1.shape[2:], mode='trilinear', align_corners=True), enc1], dim=1))
        dec1 = self.decoder1(F.interpolate(dec2, scale_factor=2, mode='trilinear', align_corners=True))

        vessel = self.vessel_out(dec1)
        kidney = self.kidney_out(dec1)

        return vessel, kidney'''



class ResNetUNet3D(nn.Module):
    """2D-encoder / 3D-decoder segmentation network.

    Every depth slice of the input volume is encoded independently by the 2D
    ResNet encoder; the per-slice feature maps are then re-stacked along a
    depth axis and decoded jointly by the 3D U-Net decoder.
    """

    def __init__(self):
        super(ResNetUNet3D, self).__init__()
        self.encoder = ResNetEncoder()
        self.decoder = UNet3DDecoder()
        # Kept so existing code referencing `net.saver` still works; no longer
        # used inside forward (see the note there).
        self.saver = IntermediateSaver()  # Assuming IntermediateSaver is defined elsewhere

    def forward(self, x):
        """Args: x — (B, C, D, H, W) volume. Returns (vessel, kidney) logits."""
        # Fold depth into the batch axis so the 2D encoder sees every slice:
        # (B, C, D, H, W) -> (B*D, C, H, W).
        B, C, D, H, W = x.shape
        x = x.view(B * D, C, H, W)

        enc4, enc3, enc2, enc1 = self.encoder(x)

        # NOTE(review): the previous version saved each encoder output through
        # IntermediateSaver and immediately loaded it back.  That round trip
        # cannot lower peak memory (the originals and the reloaded copies are
        # alive simultaneously) and — if the saver serializes raw tensor data —
        # would detach the features from the autograd graph, silently stopping
        # the encoder from training.  It has been removed.

        def _to_volume(t):
            # (B*D, C', H', W') -> (B, C', D, H', W') for the 3D decoder.
            _, c, h, w = t.shape
            return t.view(B, D, c, h, w).permute(0, 2, 1, 3, 4)

        vessel, kidney = self.decoder(_to_volume(enc4), _to_volume(enc3),
                                      _to_volume(enc2), _to_volume(enc1))
        return vessel, kidney

Train function:

def train_3d(net, train_loader, criterion, optimizer, start_epoch, total_epochs, checkpoint_dir):
    """Mixed-precision (AMP) training loop with per-epoch checkpointing.

    Args:
        net: model returning (vessel_pred, kidney_pred).
        train_loader: yields (inputs, labels) batches.
        criterion: loss applied to each head's prediction.
        optimizer: optimizer over net's parameters.
        start_epoch, total_epochs: epoch range [start_epoch, total_epochs).
        checkpoint_dir: directory receiving 'latest_checkpoint.pth'.
    """
    scaler = GradScaler()
    net.train()
    for epoch in range(start_epoch, total_epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.cuda(), labels.cuda()

            # set_to_none=True frees the gradient tensors instead of zeroing
            # them in place, lowering memory pressure between iterations.
            optimizer.zero_grad(set_to_none=True)

            with autocast():
                vessel_pred, kidney_pred = net(inputs)
                # NOTE(review): both heads are scored against the same
                # `labels` tensor — confirm it really encodes both masks.
                loss_vessel = criterion(vessel_pred, labels)
                loss_kidney = criterion(kidney_pred, labels)
                loss = loss_vessel + loss_kidney

            # Scaled backward + step for fp16 numerical stability.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            if i % 10 == 9:  # report a 10-batch average
                print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 10:.3f}")
                running_loss = 0.0

        # Release the last batch's tensors and cached blocks between epochs.
        del inputs, labels
        gc.collect()
        torch.cuda.empty_cache()  # Clear cached memory
        free_gpu_cache()

        save_checkpoint(epoch + 1, net, optimizer, loss.item(), os.path.join(checkpoint_dir, 'latest_checkpoint.pth'))
        print(f"Checkpoint saved at epoch {epoch + 1}")
# Dataset/loader construction. volume_size=(8, 256, 256) keeps each sample
# small (8 slices) and batch_size=1 keeps per-step memory low; MyDataset and
# train_meta are defined elsewhere in the file.
train_dataset = MyDataset(train_meta, volume_size=(8, 256, 256))
#test_dataset = MyDataset(test_meta, volume_size=(16, 512, 512))

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)

Thanks!

Just venturing a guess here, but 30GB of VRAM on a Kaggle machine is not enough to run a Conv3d with 3,072 input channels. Mine barely runs with 32 input channels.

Thanks for replying! Any ideas what other machines could fit?

An easy way to check if it will fit is checking the parameter count of the model. This can be done by doing this:

# Count only trainable parameters (requires_grad filters out frozen weights).
parameter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(parameter_count)

If it’s very large (in the billions-of-parameters range) then you will most likely need multiple GPUs to train the model. You can use DataParallel for that like this:

# Wrap the model in DataParallel so each batch is split across all visible GPUs.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

There are many platforms online that let you train on multiple GPUs, each with up to (or more than) 80GB of VRAM.

1 Like