Hello I got this error:
my code is
# 6. Prepare dataset
class dataset(Dataset):
def __init__(self, data_dir, image_fns):
self.data_dir = data_dir
self.image_fns = image_fns
def __len__(self):
return len(self.image_fns)
def __getitem__(self, index):
image_fn = self.image_fns[index]
image_fp = os.path.join(self.data_dir, image_fn)
image = Image.open(image_fp).convert('RGB')
image = self.transform(image)
text = image_fn.split(".")[0]
return image, text
def transform(self, image):
transform_ops = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
return transform_ops(image)
cpu_count = mp.cpu_count()
print(cpu_count)
trainset = dataset(data_path, image_fns_train)
testset = dataset(data_path, image_fns_test)
train_loader = DataLoader(trainset, batch_size=batch_size, num_workers=cpu_count, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, num_workers=cpu_count, shuffle=False)
print(len(train_loader), len(test_loader))
image_batch, text_batch = iter(train_loader).next()
print(image_batch.size(), text_batch)
And I got this error
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-67-521157e8a64b> in <module>()
1
----> 2 image_batch, text_batch = iter(train_loader).next()
3 print(image_batch.size(), text_batch)
3 frames
/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
423 # have message field
424 raise self.exc_type(message=msg)
--> 425 raise self.exc_type(msg)
426
427
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
return [default_collate(samples) for samples in transposed]
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 25, 20] at entry 0 and [3, 27, 23] at entry 1
Please help me with this, thank you so much