Trouble with ToPILImage

Why does this not work?

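import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as tf
from torchvision import transforms

# Random 28x28 picture with values in [0, 255]
pic = np.random.randint(0, 255+1, size=28*28).reshape(28, 28)
pic = pic.astype(int)  # I tried several other dtypes here as well
plt.imshow(pic)
t = transforms.ToPILImage()
t(pic.reshape(28, 28, 1))  # this is the call that fails
# tf.to_pil_image(pic.reshape(28, 28, 1))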

A beautiful random picture is plotted by matplotlib, but no matter what datatype I choose for my NumPy ndarray, neither to_pil_image nor ToPILImage works as expected.

The docs have this to say:

Converts a tensor … or a numpy ndarray of shape H x W x C to a PIL Image while preserving the value range.

If the input has 1 channel, the mode is determined by the data type (i.e int, float, short).

None of these datatypes work except for “short”.

Everything else results in:

TypeError: Input type int64/float64 is not supported

thrown from torchvision/transforms/functional.py in to_pil_image().

Further, even though the short datatype works for the standalone code snippet I provided first, it breaks down when used inside a transforms.Compose() called from a Dataset object’s __getitem__:

choices = transforms.RandomChoice([transforms.RandomAffine(30),
                                   transforms.RandomPerspective()])

transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.RandomApply([choices], 0.5),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

trainset = MNIST('data/train.csv', transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)


RuntimeError: DataLoader worker (pid 12917) is killed by signal: Floating point exception.
RuntimeError: DataLoader worker (pid(s) 12917) exited unexpectedly

Hi,

Solution:

pic = pic.astype(np.uint8)

Reason:
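The default integer dtype you get from astype(int) is 64-bit on most platforms, and to_pil_image does not support int64 (or float64) input, which is exactly what the TypeError says. If you scroll down the documentation of ToPILImage, you will find a link to all the modes that PIL supports: the "Concepts" page of the Pillow (PIL Fork) documentation.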

On this page, you can see that different dtypes correspond to different modes, so if you have a 0-255 image you need to use an 8-bit dtype such as np.uint8.
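
To make the dtype/mode relationship concrete, here is a rough sketch of the single-channel mapping as I understand it from to_pil_image. The table and the helper pil_mode_for are my own illustration (not a torchvision API), and the exact mode strings may differ between torchvision versions:

```python
import numpy as np

# Assumed single-channel dtype -> PIL mode table; verify against your torchvision version.
SINGLE_CHANNEL_MODES = {
    np.dtype(np.uint8): "L",      # 8-bit grayscale, the usual choice for 0-255 images
    np.dtype(np.int16): "I;16",   # "short"
    np.dtype(np.int32): "I",      # "int"
    np.dtype(np.float32): "F",    # "float"
}

def pil_mode_for(arr):
    """Hypothetical helper: return the expected PIL mode, or None if unsupported."""
    return SINGLE_CHANNEL_MODES.get(arr.dtype)

for dtype in (np.uint8, np.int16, np.int32, np.float32, np.int64, np.float64):
    arr = np.zeros((28, 28, 1), dtype=dtype)
    print(np.dtype(dtype).name, "->", pil_mode_for(arr))
```

The two dtypes that map to None here, int64 and float64, are the ones named in the TypeError from the question.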

I think that, given the fix for the first issue, the second one can also be considered solved, right?
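
For the Dataset case, one way to apply the same fix is to cast inside __getitem__ before the transform pipeline runs. This is only a sketch: CsvDigits is a hypothetical stand-in for the custom MNIST class in the question (whose code was not posted), and I am assuming each sample is stored as a flat row of 784 pixel values:

```python
import numpy as np

class CsvDigits:
    """Hypothetical map-style dataset standing in for the custom MNIST class."""

    def __init__(self, pixels, labels, transform=None):
        self.pixels = pixels        # assumed shape (N, 784), any integer dtype
        self.labels = labels
        self.transform = transform  # e.g. the transforms.Compose pipeline

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Cast to uint8 and reshape to H x W x C before any transform sees the data.
        image = self.pixels[idx].astype(np.uint8).reshape(28, 28, 1)
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# Usage with dummy data and no transform, just to show what the pipeline would receive.
pixels = np.random.randint(0, 256, size=(4, 784))
labels = np.arange(4)
dataset = CsvDigits(pixels, labels)
image, label = dataset[0]
print(image.dtype, image.shape, label)
```

In a real project the class would normally subclass torch.utils.data.Dataset and be given the Compose pipeline as transform; the cast is the only part that matters here.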

Best

Yessir, that indeed solves my problem.