Loading an Image Dataset Where the Target (or Ground Truth) Is Also an Image

I’m working on a project with some friends on a computer vision model where we have an image dataset and we transform the images into black and white and keep the target or ground truth as the original color image. In essence, it’s an image colorization model where we learn parameters to colorize black and white images.

I’ve seen several tutorials on ImageFolder and DataLoader, but the main problem is that although the images get loaded, the Y_train part (that is, the target) consists of labels: the name of the folder in which the images are stored (say, ‘/Dogs/’) becomes the target label. In my model, however, the final output should also be an image.

So I want to know how I can take the image dataset, make a black and white copy of each image to use as input for the CNN, and keep the original color image as the target output.

For reference, we are using the ~500 MB Places2 validation set as our main data.

If you do the transformation to black and white on the fly, you can just do it in your training loop or add it as an additional transform.
That said, it’s also very easy, and maybe instructive, to implement your own dataset that loads pairs of images (using the transforms of torchvision). Really, all you need to do is load the images, convert them to tensors and return a pair of them. PyTorch will take care of the rest.

Best regards


Thanks for the response, Tom! I wanted to ask if you could link me to some resources as to:

a) Creating your own dataset loading class. I did look up the documentation tutorial but found it a bit unwieldy for me as a beginner.
b) A comprehensive (and preferably beginner-friendly) understanding of what exactly the ImageFolder and DataLoader classes do.

I’m sorry if I’m missing something very obvious but I’m really not able to understand their exact purpose, the difference between them, and how and why to use them. Tom, if you could explain this to me as an ML amateur, I would be ever so delighted! And if possible, explaining in terms of my project use-case would probably make it more understandable!


It is really as simple as defining a class (preferably a subclass of Dataset) that has __len__ and __getitem__. __len__ is typically trivial, and __getitem__ should return a single record from the dataset for any index 0…len-1. You might look at the data loading tutorial for inspiration. You can return pretty much anything from the dataset, but a pair of tensors would probably fit best for your application.
Our book also has some elementary examples.
ImageFolder really does just the same thing, and I’d say you’re ready to dive into its source after the tutorial.

Best regards


PS: I might add that a considerable amount of magic happens in the DataLoader to get batching, parallel loading from multiple workers, etc., but you don’t need to concern yourself with that at this point.
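From the user's side, that magic boils down to wrapping a dataset in a DataLoader and iterating over it. The sketch below uses a `TensorDataset` of random tensors purely as a stand-in for a real dataset of (grayscale, color) pairs; the sizes and batch size are arbitrary assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 10 grayscale inputs and their 10 color targets.
gray = torch.rand(10, 1, 32, 32)
color = torch.rand(10, 3, 32, 32)
dataset = TensorDataset(gray, color)

# The DataLoader calls dataset[i] for several indices and stacks the results
# into batches. num_workers=0 loads in the main process; a larger value
# would load records in parallel worker processes.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_shapes = []
for gray_batch, color_batch in loader:
    batch_shapes.append((tuple(gray_batch.shape), tuple(color_batch.shape)))
```

With 10 records and a batch size of 4, the loop sees two full batches and one final batch of 2.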