I am building an Image Caption Generator using the Flickr30k dataset on Kaggle, and am getting this message.
Here are the Dataset and DataLoader classes. Can someone help me how to fix this ?
class ImageCaptionDataset(Dataset):
def __init__(self, data, word_dict, transform=None):
self.data = data
self.word_dict = word_dict
if transform != None:
self.transform = transform
else:
self.transform = transforms.Compose([transforms.ToTensor()])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_address = os.path.join('../input/flickr-image-dataset/flickr30k_images/flickr30k_images/', self.data['image_name'][idx])
img = Image.open(img_address)
img = self.transform(img)
img = img.numpy().astype('float32')
answer = str(self.data[' comment'][idx]).translate(str.maketrans('', '', string.punctuation)).split()
label = [self.word_dict[word.lower()] for word in answer]
return img, label
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
tempdb = ImageCaptionDataset(full_data, word_dict, transform)
print('Dataset Size: ',len(tempdb))
imageCaptionLoader = DataLoader(tempdb, batch_size = 16, shuffle = True)
print('Loader Size: ', len(imageCaptionLoader))
next(iter(imageCaptionLoader))
Here is the stack trace:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-7-afc3a3f8cb3d> in <module>
12 print('Loader Size: ', len(imageCaptionLoader))
13
---> 14 next(iter(imageCaptionLoader))
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
361
362 def __next__(self):
--> 363 data = self._next_data()
364 self._num_yielded += 1
365 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
401 def _next_data(self):
402 index = self._next_index() # may raise StopIteration
--> 403 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
404 if self._pin_memory:
405 data = _utils.pin_memory.pin_memory(data)
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
45 else:
46 data = self.dataset[possibly_batched_index]
---> 47 return self.collate_fn(data)
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
82 raise RuntimeError('each element in list of batch should be of equal size')
83 transposed = zip(*batch)
---> 84 return [default_collate(samples) for samples in transposed]
85
86 raise TypeError(default_collate_err_msg_format.format(elem_type))
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py in <listcomp>(.0)
82 raise RuntimeError('each element in list of batch should be of equal size')
83 transposed = zip(*batch)
---> 84 return [default_collate(samples) for samples in transposed]
85
86 raise TypeError(default_collate_err_msg_format.format(elem_type))
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
80 elem_size = len(next(it))
81 if not all(len(elem) == elem_size for elem in it):
---> 82 raise RuntimeError('each element in list of batch should be of equal size')
83 transposed = zip(*batch)
84 return [default_collate(samples) for samples in transposed]
RuntimeError: each element in list of batch should be of equal size