Assuming the tensors are loaded as [channels, height, width], you could use this Lambda transform, which repeats a single channel three times and passes 3-channel tensors through unchanged:
import torch
from torchvision import transforms

trans = transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)
x = torch.randn(3, 224, 224)
out = trans(x)
print(out.shape)
> torch.Size([3, 224, 224])
x = torch.randn(1, 224, 224)
out = trans(x)
print(out.shape)
> torch.Size([3, 224, 224])
If you are loading the images via PIL.Image.open inside your custom Dataset, you could also convert them directly to RGB via PIL.Image.open(...).convert('RGB').
However, since you are using ToPILImage
as a transformation, I assume you are loading tensors directly.