Hello All,
I am trying to implement a U-Net architecture for an image segmentation task.
Problem: an out-of-memory (OOM) error is thrown before even a single forward pass completes when the batch size is 32 or greater. With a batch size of 5, however, I can train the model without any errors.
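A rough way to see why the batch size matters so much: during training, autograd keeps the intermediate activations of every layer alive until `backward()`, and each activation tensor grows linearly with the batch size. The sketch below only does the arithmetic for a single float32 feature map. The 256x256 input resolution and the 32-channel width (assuming the model's `init_features=32`) are illustrative assumptions, not my actual settings, and a full U-Net stores many such maps per block.

```python
# Bytes needed to store one float32 feature map of shape (N, C, H, W).
def feature_map_bytes(n, c, h, w, bytes_per_elem=4):
    return n * c * h * w * bytes_per_elem

MIB = 1024 ** 2

# Hypothetical setting: 32 channels on a 256x256 input.
for batch_size in (5, 32):
    size = feature_map_bytes(batch_size, 32, 256, 256)
    print(f"batch {batch_size}: {size / MIB:.0f} MiB per feature map")
```

Under these assumptions this prints 40 MiB for batch 5 and 256 MiB for batch 32, a 6.4x increase for every stored activation, which is why a batch size that fits comfortably at 5 can run out of memory at 32.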
Model: taken from brain-segmentation-pytorch/unet.py (master branch of mateuszbuda/brain-segmentation-pytorch on GitHub).
```python
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Caravan_dataSet is the (image, mask) Dataset defined elsewhere in my script.

def _batch_trainer(net, lossFn, optimFn, batch_size=64, val_percentage=0.1, device='cpu'):
    net.train()
    # Split the dataset into training and validation subsets
    val_len = int(len(Caravan_dataSet) * val_percentage)
    train_len = len(Caravan_dataSet) - val_len
    train, val = random_split(Caravan_dataSet, [train_len, val_len])
    train_iterable = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
    val_iterable = DataLoader(val, batch_size=batch_size, shuffle=True, drop_last=True)
    train_cum_loss, val_cum_loss = 0, 0
    print("Iterator generated.......")
    for data in tqdm(train_iterable):
        imgData, maskData = data[0].detach().to(device), data[1].detach().to(device)
        print(imgData)
        #print(torch.cuda.memory_stats(device=torch.cuda.current_device()))
        optimFn.zero_grad()
        #with torch.set_grad_enabled():
        #print(torch.cuda.memory_stats(device=torch.cuda.current_device()))
        #torch.cuda.empty_cache()
        pred_masks = net(imgData)  # the OOM occurs during this forward pass for batch_size >= 32
        #except:
        #print(torch.cuda.memory_stats(device=torch.cuda.current_device()))
        # print("done with forward pass")
        # print("Pred size : ", pred_masks.size())
        loss = lossFn(pred_masks, maskData)
        loss.backward()
        train_cum_loss += float(loss)  # float() returns a plain Python number, so no graph is kept
        optimFn.step()
    # net.eval()
    # for ind, data in enumerate(val_iterable):
    #     imgData, maskData = data['Images'].to(device), data['Masked Images'].to(device)
    #     imgData, maskData = imgData.to(device), maskData.to(device)
    #     pred_masks = net(imgData)
    #     loss = lossFn(pred_masks, maskData)  ## Loss Per Batch
    #     val_cum_loss += float(loss)
    train_loss_per_epoch = train_cum_loss / len(train_iterable)
    #val_loss_per_epoch = val_cum_loss / len(val_iterable)
    print("Completed an Epoch")
    #return train_loss_per_epoch, val_loss_per_epoch  ## Loss per Epoch
    return train_loss_per_epoch
```
Although several threads have been opened about this same issue, none of the suggestions resolved it for me. I have also tried using loss.item() to append the loss values to a list.
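For reference, a small sketch of the difference that the usual `loss.item()` advice is about: summing the loss tensor itself keeps every iteration's autograd graph reachable, while `loss.item()` (or `float(loss)`, as in the loop above) yields a plain Python float. The weight tensor `w` and the toy squared-sum loss are hypothetical stand-ins for the model and loss function.

```python
import torch

w = torch.randn(3, requires_grad=True)  # stand-in for model parameters

graph_sum = 0      # becomes a tensor attached to every step's graph
float_sum = 0.0    # stays a plain Python float
for _ in range(3):
    loss = (w ** 2).sum()
    graph_sum = graph_sum + loss   # keeps references to each iteration's graph
    float_sum += loss.item()       # only the number is kept

print(graph_sum.requires_grad, type(float_sum).__name__)
```

Since the training loop already accumulates with `float(loss)`, loss bookkeeping is unlikely to be what holds the extra memory here.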
Thanks in advance