When images have an alpha channel, ImageFolder converts them to RGB when reading the data. I'm a PyTorch newbie. How would one write a custom DataLoader/Dataset to read images with 4 channels?
What is your image file format?
The images are 32-bit PNGs, with 8 bits for each of the four channels.
ImageFolder uses the pil_loader by default if accimage is not the backend in use and no other loader was specified. This loader converts the images to RGB (it calls img.convert('RGB')).
To keep the RGBA channels, you could write a custom loader (basically just copy-paste the pil_loader and convert to RGBA instead of RGB) and pass it as the loader argument to ImageFolder.
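A minimal sketch of the difference, assuming only Pillow is installed: rgb_loader mimics the default behaviour described above (convert to RGB), while rgba_loader keeps the alpha channel. Both function names and the tiny test image are made up for illustration. The convert call stays inside the with block because Image.open is lazy and needs the file to still be open when the pixel data is read.

```python
import os
import tempfile

from PIL import Image


def rgb_loader(path):
    # Behaves like the default loader: the image is forced to RGB,
    # so the alpha channel is dropped.
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")


def rgba_loader(path):
    # Same as above, but keeps (or adds) the alpha channel.
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGBA")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "example.png")
        # 2x2 semi-transparent red image, 8 bits per channel (32 bits per pixel)
        Image.new("RGBA", (2, 2), (255, 0, 0, 128)).save(path)

        print(rgb_loader(path).mode)   # RGB
        print(rgba_loader(path).mode)  # RGBA
```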
This worked. Thanks!
Uploading my solution:
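import os

from PIL import Image
from torchvision import datasets


def custom_loader(path):
    # Same as torchvision's pil_loader, but keeps the alpha channel
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGBA')


# data_dir and data_transforms are assumed to be defined elsewhere (not shown)
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x],
                                          loader=custom_loader)
                  for x in ['train', 'val']}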