Classification using 4-channel images

I have some labeled grayscale and color images. I want to combine each grayscale image with its color image into a single 4-channel image and run transfer learning on these 4-channel images. How can I do that?

This post gives you an example of your use case.

Thanks @ptrblck. What about the DataLoader part? How do I load the two types of images (grayscale and RGB) and combine them?

I would recommend writing a custom Dataset class as described here, loading the image pairs in its __getitem__ method, and concatenating them there as well.

I assume you have some mapping between the RGB and grayscale images.
If so, you could pass the mapping together with the paths for the images to __init__ and load the pairs lazily in __getitem__.

@ptrblck I did what you said, but now I get this error:
image= torch.cat((image_BW, image_RGB), 1)
TypeError: expected Tensor as element 0 in argument 0, but got BmpImageFile

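import os

import pandas as pd
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset


class bothDataset(Dataset):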
    '''
    Initialize custom Dataset for image analysis
    It should correlate image with label for training dataset
    '''

    def __init__(self, csv_path, root_dir, transform=None):
        """Init function should not do any heavy lifting, but
            must initialize how many items are available in this data set.
            
            Args:
                csv_path : Path to the csv where classes (labels) and image names are located accordingly
                root_dir : Directory with all images
                transform: Optional transform to be applied on a sample
        """
        
        #df = pd.read_csv(csv_path)#, sep=',',index_col=0)
        self.img_names = pd.read_csv(csv_path)
        self.transform = transform
        self.root_dir = root_dir
        #self.images = read_images(root + "/images")
        #self.labels = read_labels(root + "/labels")

    def __len__(self):
        """return number of points in our dataset"""

        return len(self.img_names)

    def __getitem__(self, idx):
        """ Here we have to return the item requested by `idx`
            The PyTorch DataLoader class will use this method to make an iterable for
            our training or validation loop.
        """

        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name_BW = os.path.join(self.root_dir,
                                self.img_names.iloc[idx, 0])
        image_BW = Image.open(img_name_BW+'.bmp')
        img_name_RGB = os.path.join(self.root_dir,
                                self.img_names.iloc[idx, 1])
        image_RGB = Image.open(img_name_RGB+'.bmp')
        image= torch.cat((image_BW, image_RGB), 1) 
        #image = io.imread(img_name+'.jpg')
        #image = cv2.imread(img_name+'.jpg')
        #image=image_BW
        label = self.img_names.iloc[idx, 2]
        #label = np.array([label])
        #label = label.astype('float')#.reshape(-1, 2)
        #sample = {'image': image, 'label': label}

        if self.transform:
            image = self.transform(image)
            
        return image, label


train_datak = bothDataset(csv_path='RelationBS1.csv',
                          root_dir=test_folder,
                          transform=train_transform)
print('Num training images: ', len(train_datak))
batch_size = 64
num_workers = 0

# prepare data loaders
train_loader = torch.utils.data.DataLoader(train_datak, batch_size=batch_size, 
                                           num_workers=num_workers, shuffle=True)
dataiter = iter(train_loader)
images, labels = next(dataiter)
print(type(images))
print(images.shape)

Use torchvision.transforms.functional.to_tensor on both images, as they are currently still instances of PIL.Image, while torch.cat expects a list of tensors.

@ptrblck Thanks for helping me. Do you mean like this? I still get the same error.

        img_name_BW = os.path.join(self.root_dir,
                                self.img_names.iloc[idx, 0])
        image_BW = Image.open(img_name_BW+'.bmp')
        torchvision.transforms.functional.to_tensor(image_BW)
        img_name_RGB = os.path.join(self.root_dir,
                                self.img_names.iloc[idx, 1])
        image_RGB = Image.open(img_name_RGB+'.bmp')
        torchvision.transforms.functional.to_tensor(image_RGB)
        image= torch.cat((image_BW, image_RGB), 1) 

to_tensor returns a new tensor and does not modify image_BW in place, so you have to assign the result back (and likewise for image_RGB):

image_BW = torchvision.transforms.functional.to_tensor(image_BW)

The BW and RGB images have different numbers of channels. Can I still use torch.cat? I get this error:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 1 and 3 in dimension 0
I am trying to concatenate tensors of shape (1, 224, 224) and (3, 224, 224).

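It should work if you concatenate these tensors in dim0 via torch.cat((a, b), dim=0). In an unbatched (C, H, W) image tensor the channels are in dim0; dim=1, as in your code, is only the channel dimension of a batched (N, C, H, W) tensor.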
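A small sketch with random stand-in tensors showing which dim holds the channels in the unbatched and batched layouts:

```python
import torch

bw = torch.rand(1, 224, 224)   # grayscale, (C, H, W)
rgb = torch.rand(3, 224, 224)  # color, (C, H, W)

# Unbatched (C, H, W): channels are dim 0
x = torch.cat((bw, rgb), dim=0)
print(x.shape)  # torch.Size([4, 224, 224])

# Batched (N, C, H, W): channels are dim 1
xb = torch.cat((bw.unsqueeze(0), rgb.unsqueeze(0)), dim=1)
print(xb.shape)  # torch.Size([1, 4, 224, 224])

# dim=1 on the unbatched tensors fails because dim 0 (channels) differs: 1 vs 3
try:
    torch.cat((bw, rgb), dim=1)
except RuntimeError as err:
    print("RuntimeError:", err)
```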

I am now using torch.cat((a, b), dim=0), but I still get a similar error because my images do not have the same size. How can I resize the images before concatenating them?
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 391 and 389 in dimension 1

Oh, based on your previous post, it seemed both images have a spatial size of 224.
You could use e.g. torchvision.transforms.Resize on both PIL images before converting them to tensors.

@ptrblck Thanks. But now the error is:
TypeError: pic should be PIL Image or ndarray. Got <class 'torchvision.transforms.transforms.Resize'>

image_BW = Image.open(img_name_BW+'.bmp')
image_BW = transforms.Resize(224, 224)
image_BW = torchvision.transforms.functional.to_tensor(image_BW)

I also tried converting to a PIL image before applying Resize, but I still get the same error:

image_BW=torchvision.transforms.transforms.ToPILImage(image_BW)

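Resize is a class, so you need to create an instance first and then call it on the image (the same holds for ToPILImage). Pass the output size as a single (height, width) tuple, since the second positional argument of Resize is the interpolation mode, not the width: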

res = transforms.Resize((224, 224))
image_BW = res(image_BW)
...

or use the functional API again via torchvision.transforms.functional.resize.

It is working now. Thank you so much.

Thanks, very useful. Two small questions:

  • Why use no_grad? Is there a reason to freeze the weights?
    Note that I counted the trainable parameters before and after making the change and got the same number in both cases.

  • What are these two lines for? Are they a test?
    x = torch.randn(10, 4, 224, 224)
    output = model(x)

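  1. The weights are reassigned, and this assignment shouldn't be tracked by Autograd as a differentiable operation. torch.no_grad() does not freeze anything: the parameters keep requires_grad=True and are still updated during training, which is consistent with the unchanged trainable-parameter count you observed.

  2. Yes, these lines test whether the manipulation of the parameters was successful or whether a shape mismatch error is raised.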
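A small sketch of the first point, using a hypothetical stand-in layer and random "pretrained" kernels instead of a real model: the in-place copy into the weight fails outside no_grad, succeeds inside it, and the parameter remains trainable afterwards:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a 4-channel conv and some "pretrained" 3-channel kernels
conv = nn.Conv2d(4, 8, kernel_size=3, bias=False)
pretrained_rgb = torch.randn(8, 3, 3, 3)

# Outside no_grad, autograd rejects an in-place write into a leaf parameter
try:
    conv.weight[:, 1:] = pretrained_rgb
except RuntimeError as err:
    print("RuntimeError:", err)

# Inside no_grad the copy is allowed and is not recorded in the graph
with torch.no_grad():
    conv.weight[:, 1:] = pretrained_rgb

# no_grad did not freeze anything: the weight still requires grad
print(conv.weight.requires_grad)  # True
n_trainable = sum(p.numel() for p in conv.parameters() if p.requires_grad)
print(n_trainable)  # 288 (= 8 * 4 * 3 * 3)

# Gradients still flow into the weight during training
conv(torch.randn(1, 4, 5, 5)).sum().backward()
print(conv.weight.grad is not None)  # True
```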