I have a model similar to a 2D U-Net (with additional blocks, etc.) implemented in PyTorch for medical MRI images. My dataset has about 12,000 images. During the first epoch, after roughly 5,000 images have passed through the network, the outputs become blank images and stay blank for the rest of training. I suspect class imbalance, but I am unsure how to fix it.
Any suggestions for this issue? I also have sample outputs from before the model collapses to blank images. I have lowered the learning rate to 1e-5 and experimented with different batch sizes, without success.
Could you upload a simplified version of your dataloader and model?
Below is the DataLoader. I recently switched to the sagittal view of the images, so I pad the images to a single size. I also tried cropping, but it produced the same result. The model is quite heavy, but it is essentially a U-Net that uses deformable convolutions instead of standard ones, together with squeeze-and-excitation blocks.
class DataLoaderSegmentation(data.Dataset):
    """Paired frame/mask segmentation dataset read from a directory tree.

    Expected layout (as built by the paths below)::

        file_path/<subject_id>/frame/<image files>
        file_path/<subject_id>/mask/<image files>

    Frames and masks are paired by position after sorting each directory's
    file names, so a frame and its mask must sort into the same position
    (for example, identical file names in ``frame/`` and ``mask/``).

    Hidden entries (names starting with ``.``) are skipped rather than
    deleted from disk.

    Args:
        file_path: Root directory containing one sub-directory per subject.
            Works with or without a trailing path separator.
        pad_size: ``(width, height)`` target passed to ``ImageOps.pad``.
            Defaults to ``(512, 120)``.

    Raises:
        ValueError: If a subject's ``frame/`` and ``mask/`` directories hold
            a different number of files. Index-based pairing would then
            silently mismatch frames and masks.
    """

    def __init__(self, file_path, pad_size=(512, 120)):
        self.frame_names = []
        self.mask_names = []
        self.pad_size = pad_size
        # Built once here instead of on every __getitem__ call.
        self._to_tensor = torchvision.transforms.ToTensor()

        # Sort for a reproducible index order. Skip hidden directories
        # (e.g. .ipynb_checkpoints), which have no frame/ or mask/ inside.
        subject_ids = sorted(
            name for name in next(os.walk(file_path))[1]
            if not name.startswith('.')
        )
        for subject_id in tqdm(subject_ids):
            frame_dir = os.path.join(file_path, subject_id, 'frame')
            mask_dir = os.path.join(file_path, subject_id, 'mask')
            frame_files = self._list_visible_files(frame_dir)
            mask_files = self._list_visible_files(mask_dir)
            if len(frame_files) != len(mask_files):
                raise ValueError(
                    f"Subject {subject_id!r}: {len(frame_files)} frames but "
                    f"{len(mask_files)} masks; cannot pair them by index."
                )
            self.frame_names.extend(os.path.join(frame_dir, f) for f in frame_files)
            self.mask_names.extend(os.path.join(mask_dir, m) for m in mask_files)

    @staticmethod
    def _list_visible_files(directory):
        """Return the sorted names of regular, non-hidden files in *directory*."""
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(directory, name))
        )

    def __getitem__(self, index):
        """Return the ``(frame, mask)`` tensor pair at *index*.

        Both images go through ``ImageOps.pad`` with the same target size and
        top-left ``centering=(0, 0)``, so they stay spatially aligned. Each is
        then converted with ``torchvision.transforms.ToTensor``.
        """
        with Image.open(self.frame_names[index]) as frame_img:
            # ImageOps.pad scales the image to fit pad_size (keeping the aspect
            # ratio) and then pads, so this is a resize as well as a pad.
            frame = ImageOps.pad(frame_img, self.pad_size, centering=(0, 0))
        with Image.open(self.mask_names[index]) as mask_img:
            # Nearest-neighbour keeps mask labels discrete. The default bicubic
            # resampling would blend label values along object boundaries.
            mask = ImageOps.pad(
                mask_img, self.pad_size, method=Image.NEAREST, centering=(0, 0)
            )
        # NOTE(review): ToTensor divides uint8 images by 255. If the mask files
        # store class indices (0/1) rather than 0/255, the targets become ~0.004
        # and the network can learn to output all-background (blank) images.
        # Confirm the mask encoding.
        return self._to_tensor(frame), self._to_tensor(mask)

    def __len__(self):
        """Return the number of frame/mask pairs."""
        return len(self.frame_names)