I am implementing a custom dataset class. If I set batchsize>=1, the code doesnβt work, please help me.
import glob
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
class ImageDataset(Dataset):
def __init__(self, root, transforms_=None):
self.transform = transforms.Compose(transforms_)
self.files_A = sorted(glob.glob(os.path.join(root, 'train/A') + '/*.*'))
self.files_B = sorted(glob.glob(os.path.join(root, 'train/B') + '/*.*'))
def __getitem__(self, index):
item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]))
return {'A': item_A, 'B': item_B}
def __len__(self):
return max(len(self.files_A), len(self.files_B))
transforms_ = [transforms.ToTensor()]
dataloader = DataLoader(ImageDataset('datasets/horse2zebra/', transforms_=transforms_),
batch_size=2,
shuffle=True,
num_workers=0,
drop_last=True)
for i, batch in enumerate(dataloader):
print('it works.')
Traceback (most recent call last):
File "C:/N/tt.py", line 34, in <module>
for i, batch in enumerate(dataloader):
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 615, in __next__
batch = self.collate_fn([self.dataset[i] for i in indices])
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 229, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 229, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
File "C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 209, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 1 and 3 in dimension 1 at c:\a\w\1\s\tmp_conda_3.6_091443\conda\conda-bld\pytorch_1544087948354\work\aten\src\th\generic/THTensorMoreMath.cpp:1333
Process finished with exit code 1
Directory structure:
.
βββ datasets
| βββ <dataset_name> # i.e. brucewayne2batman
| | βββ train # Training
| | | βββ A # Contains domain A images (i.e. Bruce Wayne)
| | | βββ B # Contains domain B images (i.e. Batman)
| | βββ test # Testing
| | | βββ A # Contains domain A images (i.e. Bruce Wayne)
| | | βββ B # Contains domain B images (i.e. Batman)