Out of memory error for batch size 32 and greater, before completing a single forward pass

Hello All,

I am trying to implement a U-Net architecture for an image segmentation task.
Problem: An out-of-memory error is thrown before a single forward pass completes for a batch size of 32 or greater. With a batch size of 5, I am able to train the model without any errors.

Model: taken from brain-segmentation-pytorch/unet.py (master branch) in the mateuszbuda/brain-segmentation-pytorch repository on GitHub.

def _batch_trainer(net,lossFn,optimFn,batch_size=64,val_percentage = 0.1,device='cpu'):
  val_len = int(len(Caravan_dataSet)*val_percentage)
  train_len = len(Caravan_dataSet) - val_len
  train,val = random_split(Caravan_dataSet,[train_len,val_len])
  train_iterable = DataLoader(train,batch_size=batch_size,shuffle=True,drop_last=True)
  val_iterable = DataLoader(val,batch_size=batch_size,shuffle=True,drop_last=True)
  train_cum_loss,val_cum_loss = 0,0

  print("Iterator generated.......")
  for data in tqdm(train_iterable):
    # Move the image and mask batches to the target device
    imgData,maskData = data[0].to(device),data[1].to(device)
    pred_masks = net(imgData)

    # print("done with forward pass")
    # print("Pred size : ",pred_masks.size())
    loss = lossFn(pred_masks,maskData)
    # Accumulate a Python float so the autograd graph is not kept alive
    train_cum_loss += loss.item()

  # net.eval()
  # for ind,data in enumerate(val_iterable):
  #   imgData,maskData = data['Images'].to(device),data['Masked Images'].to(device)
  #   imgData,maskData = imgData.to(device),maskData.to(device)
  #   pred_masks = net(imgData)
  #   loss = lossFn(pred_masks,maskData) ## Loss Per Batch 
  #   val_cum_loss+=float(loss)

  train_loss_per_epoch = train_cum_loss/len(train_iterable)
  #val_loss_per_epoch = val_cum_loss/len(val_iterable)
  print("Completed an Epoch")
  #return train_loss_per_epoch,val_loss_per_epoch ## Loss per Epoch  
  return train_loss_per_epoch

Although multiple queries have been raised about the same issue, I was not able to resolve it. I have also used loss.item() when appending the loss values to a list.

Thanks in advance

Your GPU doesn't seem to have enough memory for this batch size, so you would need to reduce the memory usage, e.g. by lowering the batch size or by trading compute for memory via torch.utils.checkpoint.
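To illustrate the torch.utils.checkpoint suggestion, here is a minimal sketch. ConvBlock and TinyNet are hypothetical stand-ins, not the U-Net from the linked repository, and the code assumes a PyTorch release whose checkpoint function accepts the use_reentrant argument. Each wrapped block discards its intermediate activations during the forward pass and recomputes them during backward, which should lower peak memory at the cost of extra compute.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ConvBlock(nn.Module):
    """Hypothetical stand-in for one U-Net encoder/decoder block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.body(x)


class TinyNet(nn.Module):
    """Hypothetical small network; not the model from the question."""

    def __init__(self, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.block1 = ConvBlock(3, 8)
        self.block2 = ConvBlock(8, 8)
        self.head = nn.Conv2d(8, 1, kernel_size=1)

    def _run(self, block, x):
        # Checkpoint only while training: activations inside `block` are
        # not stored and are recomputed during the backward pass.
        if self.use_checkpoint and self.training:
            return checkpoint(block, x, use_reentrant=False)
        return block(x)

    def forward(self, x):
        x = self._run(self.block1, x)
        x = self._run(self.block2, x)
        return self.head(x)


torch.manual_seed(0)
net = TinyNet(use_checkpoint=True)
net.train()

# Random tensors standing in for an image batch and its masks
images = torch.randn(4, 3, 32, 32)
masks = torch.rand(4, 1, 32, 32)

loss = nn.BCEWithLogitsLoss()(net(images), masks)
loss.backward()  # checkpointed blocks rerun their forward pass here
```

The sketch leaves out batch normalization on purpose: a checkpointed block runs its forward pass twice per step, so layers that update running statistics need extra care. How much memory is actually saved depends on which blocks are wrapped, so it is worth measuring on the real model.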
