What is color axis?

I found that some images/tensors in my dataset have torch.Size([4, 224, 224]) instead of torch.Size([3, 224, 224]), which causes the error RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 4 in dimension 1. The comment in the PyTorch code says it swaps the color axis (C). Now I of course want to change it from 4 to 3. What is the keyword for this? I searched for "python image color axis", but the search engine only returns unrelated results. Is it related to RGB? So an image with size 4 in dimension 1 is not RGB?

import torch

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image = sample['image']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image2 = image.transpose((2, 0, 1)).copy()
        tfm_image = torch.from_numpy(image2)

        # debug: report any image that does not have exactly 3 channels
        if image2.shape[0] != 3:
            print("shape image = ", image.shape, " shape image2 = ", image2.shape, " tfm_image shape = ", tfm_image.shape)

        return {'image': tfm_image,
                'species_id': sample['species_id']}
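The transpose in ToTensor only reorders the axes; it never changes how many channels there are. A minimal NumPy-only sketch (the all-zero 224 x 224 RGBA array is an assumption for illustration) showing that a 4-channel image is still 4-channel after the swap:

```python
import numpy as np

# Hypothetical RGBA image in NumPy layout: H x W x C
rgba = np.zeros((224, 224, 4), dtype=np.uint8)

# Same axis swap as in ToTensor: (H, W, C) -> (C, H, W)
chw = rgba.transpose((2, 0, 1))

print(rgba.shape)  # (224, 224, 4)
print(chw.shape)   # (4, 224, 224)
```

So the 4 comes from the image file itself, not from the transform.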

The axis swap is most likely correct.
I guess you have 4 channels because your image has an additional alpha (transparency) channel, i.e. it is RGBA instead of RGB.
The alpha channel is usually the last one, so you could just remove it with:

image2 = image2[:3]
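The same fix can also be applied before the transpose, while the array is still H x W x C, by slicing the last axis instead. A minimal sketch, assuming every image is either RGB or RGBA with alpha stored last; `to_three_channels` is a hypothetical helper name, not a PyTorch or NumPy API:

```python
import numpy as np

def to_three_channels(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return an H x W x 3 array.

    Assumes `image` is in NumPy layout (H x W x C) with C == 3 (RGB)
    or C == 4 (RGBA, alpha stored last).
    """
    if image.ndim != 3 or image.shape[2] not in (3, 4):
        raise ValueError(f"unexpected image shape {image.shape}")
    # Keep the first three channels (R, G, B); this drops alpha if present
    return image[..., :3]

rgba = np.random.randint(0, 256, size=(224, 224, 4), dtype=np.uint8)
rgb = to_three_channels(rgba)
chw = rgb.transpose((2, 0, 1))  # same swap as in ToTensor
print(chw.shape)  # (3, 224, 224)
```

Raising on any other shape makes unexpected images fail loudly instead of silently producing a wrong tensor.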

Try slicing out the fourth channel and have a look at its values. Probably they are all at the maximum, i.e. fully opaque (1.0 for float images, 255 for uint8 images).
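One way to do that check, sketched in NumPy under the assumption that the array is C x H x W with alpha as channel 3, and that float images are scaled to [0, 1]; `alpha_is_fully_opaque` is a hypothetical helper name. If it returns True, dropping the channel loses nothing; if some pixels are transparent, the RGB values underneath may not match what an image viewer shows.

```python
import numpy as np

def alpha_is_fully_opaque(chw: np.ndarray) -> bool:
    """Hypothetical check on a C x H x W array with alpha as channel 3.

    Fully opaque means every alpha value equals the maximum:
    the dtype's max for integer images (255 for uint8),
    1.0 for float images (assumed to be scaled to [0, 1]).
    """
    alpha = chw[3]
    if np.issubdtype(alpha.dtype, np.integer):
        opaque = np.iinfo(alpha.dtype).max
    else:
        opaque = 1.0
    return bool(np.all(alpha == opaque))

chw = np.full((4, 224, 224), 255, dtype=np.uint8)
print(np.unique(chw[3]))           # [255]
print(alpha_is_fully_opaque(chw))  # True
```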


Thank you very much.