Hi! I’m following the code provided in https://github.com/yunlongdong/FCN-pytorch-easiest to train an FCN, and in each epoch, I plot each mask and the prediction of the model to see it’s progress. However, I’ve noticed it’s just showing me one image, and it’s corresponding mask. Even seeing the test results, only that one image is loaded. Is there anything wrong with my dataloader?
class MyDataset(Dataset):
def __init__(self, root_dir, image_dir, mask_dir, label, transform):
self.dataset_path = root_dir
self.image_dir = image_dir
self.mask_dir = mask_dir
mask_full_path = os.path.join(self.dataset_path, self.mask_dir)
self.mask_file_list = [f for f in listdir(root_dir+mask_dir) if isfile(root_dir+join(mask_dir, f))]
random.shuffle(self.mask_file_list)
self.mapping = {
0.007782101167315175: 0,
0.5019455252918288: 0,
0.9961089494163424: 1,
1.0: 1}
self.transform = transform
def __len__(self):
return len(self.mask_file_list)
def mask_to_class(self, mask):
for k in self.mapping:
mask[mask == k] = self.mapping[k]
return mask
def __getitem__(self, idx):
file_name = self.mask_file_list[index]
img_name = os.path.join(self.dataset_path, self.image_dir, file_name.replace('.png', '.dcm'))
mask_name = os.path.join(self.dataset_path, self.mask_dir, file_name)
imgA = dicom_to_numpy(img_name)
mask_p = img_to_numpy(mask_name)
mask_p = cv2.resize(mask_p, (256, 256))
imgB = self.mask_to_class(mask_p)
imgA = cv2.resize(imgA, (256, 256))
imgB = imgB.astype('uint8')
imgB = onehot(imgB, 2)
imgB = imgB.swapaxes(0, 2).swapaxes(1, 2)
imgB = torch.FloatTensor(imgB)
if self.transform:
imgA = self.transform(imgA)
item = {'A':imgA, 'B':imgB}
return item
image_dir = 'mass/'
mask_dir = 'mask/'
label = 'mass'
fulldir = '/'
full_data = MyDataset(fulldir, image_dir, mask_dir, label, transform=transform)
train_size = int(0.8 * len(full_data))
test_size = len(full_data) - train_size
train_data, test_data = torch.utils.data.random_split(full_data, [train_size, test_size])
train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=4, shuffle=True, num_workers=4)
Thanks for your help!