Best way to deal with 1-channel images

I have a lot of single-channel images and want to make use of transfer learning. All the pretrained models available in torchvision expect 3-channel images, so for the time being I pad each 1-channel image with zeros to turn it into a 3-channel image. That doesn't seem to help, though: my loss isn't going down. Is there a better way to deal with this? I am using ResNet50.


Instead of filling the additional channels with zeros, you could convert the grayscale image to an RGB one with Image.open('PATH').convert('RGB').
This might work in most use cases.
Alternatively, you could keep your grayscale image and add a conv layer in front of your model that outputs 3 channels. For this approach you would need to train this layer, though.


I don't have image files; I have a dataframe that contains strokes, so I'm not sure how to use your first approach. The function that converts strokes to images is:

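import ast

import numpy as np
from PIL import Image, ImageDraw


def get_ims(raw_strokes):
    # "L" (8-bit grayscale) instead of "P" (palette) yields the same array of
    # 255 (background) / 0 (ink) values, but is a true grayscale image.
    image = Image.new("L", (28, 28), color=255)
    image_draw = ImageDraw.Draw(image)

    # ast.literal_eval parses the stroke string like eval, but only accepts literals.
    for stroke in ast.literal_eval(raw_strokes):
        # stroke[0] holds the x coordinates, stroke[1] the y coordinates;
        # join consecutive points with line segments.
        for i in range(len(stroke[0]) - 1):
            image_draw.line([stroke[0][i],
                             stroke[1][i],
                             stroke[0][i + 1],
                             stroke[1][i + 1]],
                            fill=0, width=6)
    return np.array(image)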
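The first approach still applies without image files: `Image.fromarray` wraps the (H, W) uint8 array that `get_ims` returns as a grayscale ("L") PIL image, and `.convert('RGB')` copies that value into R, G and B. A sketch, where `to_rgb_array` is a hypothetical helper name:

```python
import numpy as np
from PIL import Image


def to_rgb_array(gray: np.ndarray) -> np.ndarray:
    """Turn an (H, W) uint8 grayscale array into an (H, W, 3) RGB array."""
    # Image.fromarray infers mode "L" for a 2-D uint8 array;
    # convert("RGB") then sets R = G = B = the grayscale value.
    return np.array(Image.fromarray(gray).convert("RGB"))
```

For example, `rgb = to_rgb_array(get_ims(raw_strokes))` gives a (28, 28, 3) array. Equivalently, you could call `image.convert("RGB")` inside `get_ims` before `np.array(image)`.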

I’m sure someone will be able to tell you if there’s a smarter way to do this, but I think what you can do is:

RGB_image = [image] + [image] + [image]

RGB_image = [image] * 3
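Note that `[image] * 3` on its own is just a Python list holding the same array three times, not a 3-channel image. A sketch of turning it into a single (H, W, 3) array with NumPy (`replicate_channels` is a hypothetical name):

```python
import numpy as np


def replicate_channels(image: np.ndarray) -> np.ndarray:
    """Stack a 2-D (H, W) grayscale array into an (H, W, 3) array."""
    # [image] * 3 is a list of three references to the same array;
    # np.stack copies them into one array with a new last (channel) axis.
    return np.stack([image] * 3, axis=-1)
```

`transforms.ToTensor()` expects channel-last (H, W, C) arrays, so `axis=-1` fits a torchvision pipeline; use `axis=0` if you build the (3, H, W) layout yourself.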


You can make your own transform that accomplishes it for you:

transform_rgb = transforms.Lambda(lambda image: image.convert('RGB'))

I’m not sure how performant this is, though, given the lambda operation and how it’s on the PIL Image and not the tensor.


What is the difference between:

  • converting the image to RGB with Image.open('PATH').convert('RGB'), and
  • RGB_image = [image] * 3?

How do PyTorch's pretrained models handle this?