I wanted to transform images and masks. but when I tried to visualize them some of the images and masks don’t have the same flips or orientations.
class UNetDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(image_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, index):
image_path = os.path.join(self.image_dir, self.images[index])
mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif"))
image = np.array(Image.open(image_path).convert("RGB"))
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
mask[mask == 255.0] = 1.0
if self.transform is not None:
new_image = self.transform(image)
new_mask = self.transform(mask)
return new_image, new_mask
train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor()
]
)
for example like this
This is what I use to visualize
def show_random_images(dataset, n):
fig, ax = plt.subplots(n, 2, figsize=(10, 5*n))
random_numbers = random.sample(range(0, len(dataset)), n)
for i in range(n):
random_number = random_numbers[i]
img = dataset[random_number][0].permute(1,2,0)
mask = dataset[random_number][1].permute(1,2,0)
ax[i,0].imshow(img)
ax[i,0].set_axis_off()
ax[i,1].imshow(mask)
ax[i,1].set_axis_off()
plt.show()