Hi there,
I’m building a model on Kaggle with 2D ResNet50 encoder and 3D U-Net decoder for medical image segmentation, but OutOfMemoryError: CUDA out of memory error keeps hitting. I’ve tried multiple ways I found online but nothing worked for me. Any ideas on how to solve the problem?
This is my net:
os.environ['TF_FORCE_GPU_ALLOW_GROWTH']='true'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"
def free_gpu_cache():
print("Initial GPU Usage")
gpu_usage()
torch.cuda.empty_cache()
cuda.select_device(0)
cuda.close()
cuda.select_device(0)
print("GPU Usage after emptying the cache")
gpu_usage()
class ResNetEncoder(nn.Module):
def __init__(self, pretrained=True):
super(ResNetEncoder, self).__init__()
self.resnet50 = resnet50(pretrained=pretrained)
self.resnet50.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
def forward(self, x):
x = self.resnet50.conv1(x)
x = self.resnet50.bn1(x)
x = self.resnet50.relu(x)
x = self.resnet50.maxpool(x)
enc1 = checkpoint.checkpoint(self.resnet50.layer1, x)
enc2 = checkpoint.checkpoint(self.resnet50.layer2, enc1)
enc3 = checkpoint.checkpoint(self.resnet50.layer3, enc2)
enc4 = checkpoint.checkpoint(self.resnet50.layer4, enc3)
return enc4, enc3, enc2, enc1
class UNet3DDecoder(nn.Module):
def __init__(self):
super(UNet3DDecoder, self).__init__()
self.conv_reduce4 = nn.Conv3d(3072, 2048, kernel_size=1) # Add this layer to reduce channels
self.decoder4 = nn.Sequential(
nn.ConvTranspose3d(2048, 1024, kernel_size=2, stride=2),
nn.Conv3d(1024, 1024, kernel_size=3, padding=1),
nn.BatchNorm3d(1024),
nn.ReLU(inplace=True)
)
self.conv_reduce3 = nn.Conv3d(1536, 1024, kernel_size=1) # Adjust channels here as well
self.decoder3 = nn.Sequential(
nn.ConvTranspose3d(1024, 512, kernel_size=2, stride=2),
nn.Conv3d(512, 512, kernel_size=3, padding=1),
nn.BatchNorm3d(512),
nn.ReLU(inplace=True)
)
self.conv_reduce2 = nn.Conv3d(768, 512, kernel_size=1) # Adjust channels here as well
self.decoder2 = nn.Sequential(
nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2),
nn.Conv3d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm3d(256),
nn.ReLU(inplace=True)
)
self.decoder1 = nn.Sequential(
nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2),
nn.Conv3d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm3d(128),
nn.ReLU(inplace=True)
)
self.vessel_out = nn.Conv3d(128, 1, kernel_size=1)
self.kidney_out = nn.Conv3d(128, 1, kernel_size=1)
def forward(self, enc4, enc3, enc2, enc1):
#print("Shape of enc4 before interpolation:", enc4.shape)
#print("Shape of enc3:", enc3.shape)
interpolated_enc4 = F.interpolate(enc4, size=enc3.shape[2:], mode='trilinear', align_corners=True)
concatenated = torch.cat([interpolated_enc4, enc3], dim=1)
concatenated = self.conv_reduce4(concatenated) # Reduce channels here
dec4 = self.decoder4(concatenated)
#print("Shape of dec4:", dec4.shape)
#print("Shape of enc2:", enc2.shape)
interpolated_dec4 = F.interpolate(dec4, size=enc2.shape[2:], mode='trilinear', align_corners=True)
concatenated = torch.cat([interpolated_dec4, enc2], dim=1)
concatenated = self.conv_reduce3(concatenated) # Reduce channels here
dec3 = self.decoder3(concatenated)
#print("Shape of dec3:", dec3.shape)
#print("Shape of enc1:", enc1.shape)
interpolated_dec3 = F.interpolate(dec3, size=enc1.shape[2:], mode='trilinear', align_corners=True)
concatenated = torch.cat([interpolated_dec3, enc1], dim=1)
concatenated = self.conv_reduce2(concatenated) # Reduce channels here
dec2 = self.decoder2(concatenated)
#print("Shape of dec2:", dec2.shape)
interpolated_dec2 = F.interpolate(dec2, scale_factor=2, mode='trilinear', align_corners=True)
dec1 = self.decoder1(interpolated_dec2)
#print("dec1 = interpolated dec2 =", dec1.shape)
vessel = self.vessel_out(dec1)
kidney = self.kidney_out(dec1)
return vessel, kidney
''''def forward(self, enc4, enc3, enc2, enc1):
dec4 = self.decoder4(torch.cat([F.interpolate(enc4, size=enc3.shape[2:], mode='trilinear', align_corners=True), enc3], dim=1))
dec3 = self.decoder3(torch.cat([F.interpolate(dec4, size=enc2.shape[2:], mode='trilinear', align_corners=True), enc2], dim=1))
dec2 = self.decoder2(torch.cat([F.interpolate(dec3, size=enc1.shape[2:], mode='trilinear', align_corners=True), enc1], dim=1))
dec1 = self.decoder1(F.interpolate(dec2, scale_factor=2, mode='trilinear', align_corners=True))
vessel = self.vessel_out(dec1)
kidney = self.kidney_out(dec1)
return vessel, kidney'''
class ResNetUNet3D(nn.Module):
def __init__(self):
super(ResNetUNet3D, self).__init__()
self.encoder = ResNetEncoder()
self.decoder = UNet3DDecoder()
self.saver = IntermediateSaver() # Assuming IntermediateSaver is defined elsewhere
def forward(self, x):
B, C, D, H, W = x.shape
x = x.view(B * D, C, H, W)
# Save intermediate feature maps to reduce memory usage
enc4, enc3, enc2, enc1 = self.encoder(x)
enc4_path = self.saver.save(enc4)
enc3_path = self.saver.save(enc3)
enc2_path = self.saver.save(enc2)
enc1_path = self.saver.save(enc1)
enc4 = self.saver.load(enc4_path)
enc3 = self.saver.load(enc3_path)
enc2 = self.saver.load(enc2_path)
enc1 = self.saver.load(enc1_path)
_, C_enc4, H_enc4, W_enc4 = enc4.shape
_, C_enc3, H_enc3, W_enc3 = enc3.shape
_, C_enc2, H_enc2, W_enc2 = enc2.shape
_, C_enc1, H_enc1, W_enc1 = enc1.shape
enc4 = enc4.view(B, D, C_enc4, H_enc4, W_enc4).permute(0, 2, 1, 3, 4)
enc3 = enc3.view(B, D, C_enc3, H_enc3, W_enc3).permute(0, 2, 1, 3, 4)
enc2 = enc2.view(B, D, C_enc2, H_enc2, W_enc2).permute(0, 2, 1, 3, 4)
enc1 = enc1.view(B, D, C_enc1, H_enc1, W_enc1).permute(0, 2, 1, 3, 4)
vessel, kidney = self.decoder(enc4, enc3, enc2, enc1)
return vessel, kidney
Train function:
def train_3d(net, train_loader, criterion, optimizer, start_epoch, total_epochs, checkpoint_dir):
scaler = GradScaler()
net.train()
for epoch in range(start_epoch, total_epochs):
running_loss = 0.0
for i, data in enumerate(train_loader):
inputs, labels = data
inputs, labels = inputs.cuda(), labels.cuda()
optimizer.zero_grad()
with autocast():
vessel_pred, kidney_pred = net(inputs)
loss_vessel = criterion(vessel_pred, labels)
loss_kidney = criterion(kidney_pred, labels)
loss = loss_vessel + loss_kidney
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
running_loss += loss.item()
if i % 10 == 9:
print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 10:.3f}")
running_loss = 0.0
del inputs, labels
gc.collect()
torch.cuda.empty_cache() # Clear cached memory
free_gpu_cache()
save_checkpoint(epoch + 1, net, optimizer, loss.item(), os.path.join(checkpoint_dir, 'latest_checkpoint.pth'))
print(f"Checkpoint saved at epoch {epoch + 1}")
train_dataset = MyDataset(train_meta, volume_size=(8, 256, 256))
#test_dataset = MyDataset(test_meta, volume_size=(16, 512, 512))
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
Thanks!