How to use RGBA or 4-channel images in a DCGAN model in PyTorch

Hello, many thanks to all contributors in advance.

I am trying to apply PyTorch's DCGAN model; however, when I load the dataset it defaults to 3-channel images, even though my images are 4-channel RGBA.

How can I use RGBA or 4-channel images in a DCGAN model in PyTorch?


dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

This is the code I use to create the dataset. I would like to modify it so it can use my images, which have the shape [647, 4, 256, 256].

You would need to modify the first conv layer of the model, which I guess currently accepts tensors with 3 channels.
However, are you sure the alpha channel contains any (useful) information and is not just set to the highest value for all pixels? In the latter case, I would claim you can directly remove this channel, as it doesn't contain any information.
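As a minimal sketch (the layer width ndf here is an assumption for illustration, not the exact DCGAN tutorial value), a first conv layer that accepts 4-channel input would look like:

```python
import torch
import torch.nn as nn

nc = 4    # input channels: RGBA instead of RGB
ndf = 32  # feature maps in the first layer (assumed value)

# first conv of a DCGAN-style discriminator, now accepting 4-channel input
first_conv = nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False)

x = torch.randn(1, nc, 256, 256)  # one RGBA image as a tensor
out = first_conv(x)
print(out.shape)  # torch.Size([1, 32, 128, 128])
```

With stride 2 and padding 1 the spatial size is halved, as usual in the DCGAN discriminator; only the in_channels argument needs to change from 3 to 4.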

When executing my code without any changes, I get the following error:

Given groups=1, weight of size [32, 4, 4, 4], expected input[1, 3, 256, 256] to have 4 channels, but got 3 channels instead 

This is because dset.ImageFolder is converting the images to 3 channels. How can I make it read all four channels?

The information in the alpha channel is of supreme importance, since the images are physical representations of the axes X, Y, Z, Z′ of a physical quantity.

Thank you very much for taking the time to respond!

You could provide a custom loader to ImageFolder, as the default loader transforms the images to RGB, as seen here.


Hi,
As you know, we can use only 1 or 3 channels in Keras, etc. But I just realized that
we can use any number of channels in PyTorch, based on this link:

“You can write a custom dataset class for reading images and make them 16 channels input.
Check [link] for writing a custom dataset.”
But according to what you wrote here, does ImageFolder ultimately convert the input to RGB?
def pil_loader(path: str) -> Image.Image:
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")
Is that correct?

A custom Dataset allows you to load and process the data the way you want. It would come down to loading an image with 16 channels, as this is not a standard image format. However, if you have stored these samples as binary data, you could just load them.
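A sketch of such a Dataset, assuming (hypothetically) the 16-channel samples are stored as .npy arrays of shape [16, H, W]:

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class BinarySampleDataset(Dataset):
    """Loads 16-channel samples stored as .npy files (hypothetical layout)."""

    def __init__(self, paths):
        self.paths = paths  # list of file paths, one per sample

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        arr = np.load(self.paths[idx])        # array of shape [16, H, W]
        return torch.from_numpy(arr).float()  # tensor the model can consume
```

The storage format is up to you; the only requirement is that __getitem__ returns a tensor.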

Many thanks for answering.
I have another question: if I can load them, can the conv layer deal with these images, or will the binary image be converted to an RGB format?
In Keras, we have to use 1 or 3 channels. Is there no such constraint in PyTorch? Can we load images in any format and with any number of channels?

You are creating the conv layer yourself, and as long as you specify in_channels=16 in the first conv layer, it will be able to process inputs with 16 channels.

Again, you can create a custom Dataset and load whatever you want. Samples with 16 channels are not images, so PIL or OpenCV won't be able to load them. However, as long as you can load the sample somehow inside the Dataset, you will be able to use it to train your model.


Oh, finally I understand your meaning! The main challenge is loading the images! If we can load them, we can train the networks with them. Is that correct?
In the conv layers, PyTorch can handle weights for any number of channels! So the main constraint is that we must load images as grayscale, RGB, or a tensor (in PyTorch we deal with tensors)? For example, if I have FITS images, I cannot train on them in PyTorch. Is that correct? Because this is an unknown format in PyTorch!

PyTorch doesn't know anything about image formats and just expects tensors.
If you can load your 16-channel samples using any library and transform them somehow to tensors (e.g. from numpy arrays), it will work.
What PyTorch cares about is:

import torch
import torch.nn as nn

input = torch.randn(1, 16, 224, 224)
first_conv = nn.Conv2d(16, 32, 3, 1, 1)

out = first_conv(input)

You have to make sure input is a tensor with 16 channels.
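For instance, a sample loaded with any array library (NumPy here; the FITS decoding step itself is outside PyTorch) can be turned into such a tensor:

```python
import numpy as np
import torch

# e.g. the array decoded from a FITS file or raw binary data
arr = np.random.rand(16, 224, 224).astype(np.float32)

t = torch.from_numpy(arr).unsqueeze(0)  # add a batch dimension -> [1, 16, 224, 224]
print(t.shape)  # torch.Size([1, 16, 224, 224])
```

From that point on, the model neither knows nor cares what file format the data originally came from.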


Many thanks for your guidance and time.