I am using the following code to reconstruct a 3D volume from the 2D segmented slices produced by my 2D model. I do not want to calculate a loss here. My batch size is 4, and each input passed to the 2D model has shape [1, 1, 256, 256] (one slice at a time). Despite this small input size, I still get a CUDA out-of-memory error:

```text
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.70 GiB total capacity; 5.36 GiB already allocated; 13.00 MiB free; 5.37 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```
```python
for epoch in range(1, num_epochs + 1):
    model.train()
    train_dice_total = 0.0
    num_steps = 0
    for i, batch in enumerate(train_dice_loader):
        print(len(batch))
        input_samples, gt_samples, voxel_dim = batch
        input_samples = input_samples.float()
        if torch.cuda.is_available():
            input_samples = input_samples.cuda(device="cuda")
            var_gt = gt_samples.cuda(device="cuda")
            model = model.cuda()
        # Initialize an empty tensor to store the segmented volume
        segmented_volume = torch.zeros((input_samples.shape))
        # Iterate over each image in the batch
        for img_id in range(input_samples.shape[0]):
            # Get the slices for the current image
            img_slices = input_samples[img_id]
            # Initialize an empty tensor to store the segmented slices for the current image
            segmented_img_slices = torch.zeros((img_slices.shape))
            # Iterate over each slice in the current image
            for slice_id in range(img_slices.shape[0]):
                # Get the current slice
                slice = img_slices[slice_id]
                # Add a batch dimension to the current slice
                slice = slice.unsqueeze(0)
                print("slice", slice.shape)
                # Pass the current slice through the model to get the segmentation mask
                segmented_slice = model(slice)
                # Remove the batch dimension from the segmented slice
                segmented_slice = segmented_slice.squeeze(0)
                # Store the segmented slice in the tensor for the current image
                segmented_img_slices[slice_id] = segmented_slice
                # embed()
            # Combine the segmented slices for the current image into a 3D volume
            segmented_images_i = segmented_img_slices.permute(1, 0, 2, 3).unsqueeze(0)
            # Store the segmented volume for the current image in the batch tensor
            segmented_volume[img_id] = segmented_images_i
        # Remove the batch dimension from the segmented volume
        segmented_volume = segmented_volume.squeeze(1)
```
How can I resolve this error?
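Since no loss is computed here, no autograd graph is needed. In the loop above each model output still requires grad, so copying it into `segmented_img_slices` likely keeps every slice's graph (and its GPU activations) alive until the whole volume is discarded. For comparison, here is a minimal sketch of the same slice-by-slice reconstruction run under `torch.no_grad()`. It assumes each volume has layout [D, 1, H, W] and that the model maps [1, 1, H, W] to [1, C, H, W]. The function name `reconstruct_volume` and the `nn.Conv2d` stand-in model are hypothetical, used only for illustration.

```python
import torch
import torch.nn as nn


def reconstruct_volume(model: nn.Module, volume: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: segment a volume slice by slice and stack the results.

    Assumes `volume` has shape [D, 1, H, W] and `model` maps [1, 1, H, W] to
    [1, C, H, W]. Returns a CPU tensor of shape [C, D, H, W].
    """
    device = next(model.parameters()).device
    model.eval()  # inference behaviour for dropout / batch norm layers
    slices = []
    # No loss is computed, so disable autograd: no graph or saved activations are kept.
    with torch.no_grad():
        for slice_id in range(volume.shape[0]):
            x = volume[slice_id].unsqueeze(0).to(device)  # [1, 1, H, W]
            out = model(x).squeeze(0)                      # [C, H, W]
            slices.append(out.cpu())                       # move each result off the GPU
    # Stack the per-slice outputs along a new depth axis: [C, D, H, W]
    return torch.stack(slices, dim=1)


# Hypothetical stand-in for the real 2D segmentation model
toy_model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
if torch.cuda.is_available():
    toy_model = toy_model.cuda()

batch = torch.randn(4, 8, 1, 64, 64)  # assumed layout [B, D, 1, H, W]
volumes = torch.stack([reconstruct_volume(toy_model, v) for v in batch])
print(volumes.shape)  # torch.Size([4, 1, 8, 64, 64])
```

`torch.inference_mode()` could be used in place of `torch.no_grad()` for the same purpose. Separately, the error reports 5.37 GiB reserved by PyTorch but only 13 MiB free on a 23.70 GiB card, so it may also be worth checking whether other processes are using GPU 0.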