Convert grayscale images to RGB

Assuming the tensors are loaded as [channels, height, width], you could probably use this lambda transformation:

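import torch
from torchvision import transforms

# Repeat a single channel 3 times; tensors that already have 3 channels pass through unchanged
trans = transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)

x = torch.randn(3, 224, 224)
out = trans(x)
print(out.shape)
# torch.Size([3, 224, 224])

x = torch.randn(1, 224, 224)
out = trans(x)
print(out.shape)
# torch.Size([3, 224, 224])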
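If the extra copy made by repeat is a concern, a variant could use expand instead. expand returns a view with stride 0 on the channel dimension, so no memory is copied. Indexing the channel dimension as -3 also lets the same function handle batched [N, 1, H, W] input. This is only a sketch; the name gray_to_rgb is hypothetical:

```python
import torch


def gray_to_rgb(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: expand a single-channel image to 3 channels.

    Works for [C, H, W] and batched [N, C, H, W] tensors (channel dim = -3).
    Inputs that do not have exactly one channel are returned unchanged.
    """
    if x.dim() < 3:
        raise ValueError(f"expected at least 3 dims, got shape {tuple(x.shape)}")
    if x.size(-3) != 1:
        return x
    # expand() creates a view (stride 0 on the channel dim), so no data is copied.
    shape = list(x.shape)
    shape[-3] = 3
    return x.expand(*shape)


print(gray_to_rgb(torch.randn(1, 224, 224)).shape)
# torch.Size([3, 224, 224])
print(gray_to_rgb(torch.randn(8, 1, 32, 32)).shape)
# torch.Size([8, 3, 32, 32])
```

It could be wrapped the same way as before via transforms.Lambda(gray_to_rgb). Because all three channels share the same memory, an in-place op on the result (e.g. a Normalize with inplace=True) would fail, so call .contiguous() first or stick with repeat in that case.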

If you are loading the images via PIL.Image.open inside your custom Dataset, you could also convert them directly to RGB via PIL.Image.open(...).convert('RGB').
However, since you are using ToPILImage as a transformation, I assume you are loading tensors directly.
