lepoeme20
(seungwan seo)
February 7, 2020, 9:05am
Hello,
Can I get both RGB and grayscale images simultaneously from an image DataLoader?
Something like below:
for step, (rgb, gray) in enumerate(dataloader):
Or is there a way to convert rgb images into grayscale within each batch?
for step, (imgs, labels) in enumerate(dataloader):
    grayscale = ANY_FUNCTION(imgs).to(device)
    rgb = imgs.to(device)
I need both RGB and grayscale images in each mini-batch.
Hello,
I suppose that you have a custom dataset class that inherits from torch.utils.data.Dataset. In your __getitem__ function, simply return both the RGB and the grayscale image. Here is an example:
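import cv2
import numpy as np
import torch

class CustomDatasetRGB(torch.utils.data.Dataset):
    def __init__(self, images, labels):
        self.images = images  # list of image file paths
        self.labels = labels  # one label per image

    def __getitem__(self, index):
        # cv2.imread returns an H x W x 3 array in BGR channel order
        bgr_image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
        gray_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        # cv2.imshow('gray', gray_image)
        # cv2.imshow('color', bgr_image)  # imshow expects BGR
        # cv2.waitKey(0)
        # H x W x C -> C x H x W, the layout PyTorch expects
        rgb_image = np.transpose(rgb_image, axes=(2, 0, 1))
        return torch.from_numpy(rgb_image), torch.from_numpy(gray_image), self.labels[index]

    def __len__(self):
        return len(self.labels)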
# 1000 images in dataset
images = ['im2.png'] * 1000
labels = np.random.randint(low=0, high=10, size=1000)
dataset = CustomDatasetRGB(images, labels)
loader = torch.utils.data.DataLoader(dataset, batch_size=4)
for step, (rgb, gray, labels) in enumerate(loader):
    print(rgb.size())
    print(gray.size())
    print(labels.size())
    break  # one batch is enough to inspect the shapes
'''
torch.Size([4, 3, 375, 450])
torch.Size([4, 375, 450])
torch.Size([4])
'''
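As for the second option in the question (converting inside the training loop instead of in the dataset), here is a minimal sketch of what could stand in for ANY_FUNCTION. The name rgb_to_grayscale is hypothetical, not a PyTorch function. It assumes the batch has shape (N, 3, H, W) with channels in RGB order, and it uses the ITU-R BT.601 luma weights (0.299, 0.587, 0.114), which are the weights OpenCV documents for its RGB-to-gray conversion.

```python
import torch

def rgb_to_grayscale(imgs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (N, 3, H, W) RGB batch -> (N, 1, H, W) grayscale batch."""
    # BT.601 luma weights for the R, G and B channels (assumes RGB channel order)
    weights = torch.tensor([0.299, 0.587, 0.114], dtype=torch.float32, device=imgs.device)
    imgs = imgs.float()  # uint8 images must be cast before the weighted sum
    # Broadcast the weights over the channel dimension, then sum the channels
    return (imgs * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)

# Stand-in for one batch coming out of the dataloader
imgs = torch.randint(0, 256, (4, 3, 375, 450), dtype=torch.uint8)
gray = rgb_to_grayscale(imgs)
print(gray.shape)  # torch.Size([4, 1, 375, 450])
```

The result is a float tensor that keeps a channel dimension of size 1, so it is not rounded back to uint8 and may differ slightly from what cv2.cvtColor produces. Since the function only uses tensor operations, it runs on whatever device the batch is already on, so you can move imgs to the GPU first and convert there.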