Hello, I have this dataset, and I make a data loader, it returns image, mask, and label. but when I make use it, it occurs this issue.
Traceback (most recent call last):
File "test/test_dataset.py", line 16, in test_dataset
for i, j in enumerate(dataload):
File "/home/ubuntu/fzh/.conda/envs/rando/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 582, in __next__
return self._process_next_batch(batch)
File "/home/ubuntu/fzh/.conda/envs/rando/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
RuntimeError: Traceback (most recent call last):
File "/home/ubuntu/fzh/.conda/envs/rando/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/ubuntu/fzh/.conda/envs/rando/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 68, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/ubuntu/fzh/.conda/envs/rando/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 68, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/ubuntu/fzh/.conda/envs/rando/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 52, in default_collate
return default_collate([torch.from_numpy(b) for b in batch])
File "/home/ubuntu/fzh/.conda/envs/rando/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 43, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: Expected object of scalar type Byte but got scalar type Short for sequence element 1 in sequence argument at position #1 'tensors'
and here is what I have done briefly:
def __getitem__(self, idx):
im = []
for i in range(self.num_input):
direct, _ = self.root_dir[self.num_input * idx + i].split("\n")
if i < self.num_input - 1:
image = nib.load(direct).get_data()
image = np.expand_dims(image, axis=0)
im.append(image)
if i == 0 :
direct = os.path.split(direct)[0] + "/mask"
mask = nib.load(direct + "/mask.nii.gz").get_data()
else:
labels =nib.load(direct).get_data()
labels = np.asarray(labels)
iamges = np.concatenate(im, axis=0).astype(float)
# iamges shape: 4 X H X W X D
# labels shape : HXWXD
# mask shape : HxWxD
images = np.transpose(iamges,(0,3,1,2))
labels = np.transpose(labels,(2,0,1))
mask = np.transpose(mask,(2,0,1))
return images,labels,mask
Thank you for your help