I am using the following code to reconstruct a 3D volume from the 2D segmented slices that I get from my 2D model. I do not want to compute a loss here. My batch size is 4, and each slice passed to the 2D model has shape [1, 1, 256, 256] (one slice at a time). Even with slices this small, I am still getting a CUDA out-of-memory error:
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.70 GiB total capacity; 5.36 GiB already allocated; 13.00 MiB free; 5.37 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
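For reference, the max_split_size_mb setting that the error message mentions is an allocator option read from the PYTORCH_CUDA_ALLOC_CONF environment variable before the first CUDA allocation; the value below is only an illustrative choice, and as far as I understand it only addresses fragmentation, not the steady growth in allocated memory:

import os
# Must be set before the first CUDA allocation (e.g. before any .cuda() call);
# 128 is just an example value, not a recommendation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"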
for epoch in range(1, num_epochs + 1):
    model.train()
    train_dice_total = 0.0
    num_steps = 0

    for i, batch in enumerate(train_dice_loader):
        print(len(batch))
        input_samples, gt_samples, voxel_dim = batch
        input_samples = input_samples.float()

        if torch.cuda.is_available():
            input_samples = input_samples.cuda(device="cuda")
            var_gt = gt_samples.cuda(device="cuda")
            model = model.cuda()

        # Initialize an empty tensor to store the segmented volume
        segmented_volume = torch.zeros(input_samples.shape)

        # Iterate over each image in the batch
        for img_id in range(input_samples.shape[0]):
            # Get the slices for the current image
            img_slices = input_samples[img_id]

            # Initialize an empty tensor to store the segmented slices for the current image
            segmented_img_slices = torch.zeros(img_slices.shape)

            # Iterate over each slice in the current image
            for slice_id in range(img_slices.shape[0]):
                # Get the current slice
                slice = img_slices[slice_id]
                # Add a batch dimension to the current slice
                slice = slice.unsqueeze(0)
                print("slice", slice.shape)
                # Pass the current slice through the model to get the segmentation mask
                segmented_slice = model(slice)
                # Remove the batch dimension from the segmented slice
                segmented_slice = segmented_slice.squeeze(0)
                # Add the segmented slice to the set of segmented slices for the current image
                segmented_img_slices[slice_id] = segmented_slice

            # Combine the segmented slices for the current image into a 3D volume
            segmented_images_i = segmented_img_slices.permute(1, 0, 2, 3).unsqueeze(0)
            # Add the segmented volume for the current image to the batch of segmented volumes
            segmented_volume[img_id] = segmented_images_i

        # Remove the channel dimension from the segmented volume
        segmented_volume = segmented_volume.squeeze(1)
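Since I never compute a loss in this loop, I am wondering whether the fix is simply to run the per-slice forward passes without autograd and keep the assembled volume on the CPU. Below is a minimal sketch of what I mean; the shapes assume input_samples is [batch, depth, 1, 256, 256], which is an assumption about my own data layout, and I am not certain this is the right fix:

# Sketch only: inference without an autograd graph, results collected on the CPU.
model.eval()  # assuming pure inference, so dropout/batchnorm run in eval mode
with torch.no_grad():
    segmented_volume = torch.zeros(input_samples.shape)  # CPU buffer
    for img_id in range(input_samples.shape[0]):
        img_slices = input_samples[img_id]                # [depth, 1, 256, 256]
        for slice_id in range(img_slices.shape[0]):
            slice_2d = img_slices[slice_id].unsqueeze(0)  # [1, 1, 256, 256]
            segmented_slice = model(slice_2d).squeeze(0)  # [1, 256, 256]
            segmented_volume[img_id, slice_id] = segmented_slice.cpu()

My understanding is that, without torch.no_grad(), every model(slice) call keeps its autograd graph (and all intermediate activations) alive on the GPU because the outputs are stored for later use, so memory grows with every slice.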
How can I resolve this error?