class CasingDataset2(Dataset):
def __init__(self, csv_file: str = "casing_model.csv"):
"""Construct an instance."""
self.len = 10
def __getitem__(
self, index: int
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
"""Get an item."""
return torch.FloatTensor([1, 2, 3]), torch.LongTensor(index)
def __len__(self) -> int:
"""Return data length."""
return self.len
And I try run it with following arguments. It works
train_dataset2 = CasingDataset2()
train_loader2 = DataLoader(train_dataset2, batch_size=1, shuffle=True, num_workers=1, drop_last=False)
for idx, (data, target) in enumerate(train_loader2):
print(idx)
0
1
2
3
4
5
6
7
8
9
Then I try batch_size=2
. It raises an error
train_dataset2 = CasingDataset2()
train_loader2 = DataLoader(train_dataset2, batch_size=2, shuffle=True, num_workers=1, drop_last=False)
for idx, (data, target) in enumerate(train_loader2):
print(idx)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-48-8986358bcacc> in <module>
----> 1 for idx, (data, target) in enumerate(train_loader2):
2 print(idx)
/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
515 if self._sampler_iter is None:
516 self._reset()
--> 517 data = self._next_data()
518 self._num_yielded += 1
519 if self._dataset_kind == _DatasetKind.Iterable and \
/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
1197 else:
1198 del self._task_info[idx]
-> 1199 return self._process_data(data)
1200
1201 def _try_put_index(self):
/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1223 self._try_put_index()
1224 if isinstance(data, ExceptionWrapper):
-> 1225 data.reraise()
1226 return data
1227
/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/_utils.py in reraise(self)
427 # have message field
428 raise self.exc_type(message=msg)
--> 429 raise self.exc_type(msg)
430
431
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
data = fetcher.fetch(index)
File "/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
return [default_collate(samples) for samples in transposed]
File "/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/i/pyenv/versions/py-default/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [0] at entry 0 and [1] at entry 1
version: 1.8.1+cu102
Why I can’t use another batch_size
?