Dataset transform error loading int16 when using DataLoader


#1

Hi everyone,

I am trying to load a 3D dataset using both the Dataset class and the DataLoader. My image data is an int16 ndarray. Loading it works when I use just the Dataset, but breaks when I use the DataLoader, with the following error:

RuntimeError: can't convert a given np.ndarray to a tensor - it has an invalid type. The only supported types are: double, float, int64, int32, and uint8.

Now I am wondering why the Dataset accepts loading from a numpy int16 ndarray but the DataLoader does not.

My ToTensor class looks like this:

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # Expand with channel axis
        # numpy image: H x W x Z
        # torch image: C x H x W x Z
        #print(image.dtype)
        image = torch.from_numpy(image).unsqueeze(0)
        label = torch.ByteTensor(label)

        print(image.size(), label.size())
        return {'image': image,
                'label': label}

ppmi_dataset = PPMIDataset(excel_file='PPMI_main_table.xlsx',
                           root_dir=FOLDER, transform=ToTensor())

print(len(ppmi_dataset))
a = ppmi_dataset[0]['image'].numpy()
print(a.dtype)
plt.imshow(a.squeeze()[:,:,45], cmap='gray')
plt.show()

This works correctly and prints the following (together with the image):

606
torch.Size([1, 105, 127, 105]) torch.Size([1])
int16

Using it with the DataLoader:

dataloader = DataLoader(ppmi_dataset, batch_size=2, shuffle=True)

for i, batch in enumerate(dataloader):
    if i == 2:
        plt.figure()
        for idx in range(1):
            plt.subplot(1, 2, 1 + idx)
            # The default collate returns a dict of batched tensors, so index the key first
            plt.imshow(batch['image'][idx].squeeze().numpy()[:,:,45])
        plt.show()
        break

This results in the error from above. Converting the ndarray to int32 in the ToTensor class works in both cases. Is this some intended feature that I don’t see?
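For reference, the int32 workaround is just a cast before torch.from_numpy. Below is a minimal sketch, assuming each sample is a dict with 'image' and 'label' arrays as above. The class name ToTensorInt32 is made up for this illustration, and building the label via a uint8 cast (instead of torch.ByteTensor) is my own assumption, not part of the original transform.

```python
import numpy as np
import torch


class ToTensorInt32(object):
    """Convert ndarrays in sample to Tensors, casting the image to int32 first."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # Cast to int32 so every sample ends up with the same supported dtype,
        # no matter how the volume was stored on disk (int16, uint16, ...).
        image = np.ascontiguousarray(image, dtype=np.int32)
        # numpy image: H x W x Z -> torch image: C x H x W x Z
        image = torch.from_numpy(image).unsqueeze(0)
        # Assumption: labels are small non-negative integers that fit in uint8.
        label = torch.from_numpy(np.atleast_1d(label).astype(np.uint8))
        return {'image': image, 'label': label}


# Stand-in sample with a uint16 image, the dtype that caused trouble here.
sample = {'image': np.zeros((4, 5, 6), dtype=np.uint16), 'label': np.array([1])}
out = ToTensorInt32()(sample)
print(out['image'].dtype, out['image'].size(), out['label'].dtype)
```

The cast costs an extra copy per sample and doubles the memory of an int16 volume, so it is a trade-off rather than a free fix.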


(Simon Wang) #2

I was wrong. See @derEitel's post below.


#3

What about ShortTensor?
http://pytorch.org/docs/master/tensors.html#torch.ShortTensor

Btw: I found that the problem occurred because some of my labels were loaded as uint16 and some as int16. The former type is not supported and breaks the from_numpy() function.
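In case it helps someone else track down a similar mix-up: one way to spot it is to tally the dtypes of the raw arrays before any tensor conversion. This is only a sketch. The helper name dtype_counts is hypothetical and the labels list is stand-in data, not the actual PPMI labels.

```python
from collections import Counter

import numpy as np


def dtype_counts(arrays):
    """Count how many arrays have each numpy dtype (keyed by dtype name)."""
    return Counter(str(np.asarray(a).dtype) for a in arrays)


# Stand-in data: in practice these would be the raw label (or image) arrays
# that the dataset reads from disk, before ToTensor is applied.
labels = [np.array([0], dtype=np.int16),
          np.array([1], dtype=np.uint16),
          np.array([1], dtype=np.int16)]

counts = dtype_counts(labels)
print(counts)
```

If more than one dtype shows up, casting everything to a single supported type in the transform keeps both from_numpy and the batching in the DataLoader consistent.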