How can we work with one-channel images?

Hello everyone,
I’m trying to train on images with one channel (that’s what I want for a specific reason), but PyTorch refuses to deal with them. Is there any way to work with one-channel images?

This is the error I get:

RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

Could you post a small code snippet or the line of code which throws this error?
If you would like to use a pretrained model, you should pass an image containing 3 channels.
For grayscale images you could either repeat the channel dimension manually, use torchvision.transforms.Grayscale(num_output_channels=3), or add a conv layer at the beginning of the model to “learn” the mapping from 1 to 3 channels.
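This particular error is usually raised by transforms.Normalize when a 3-element mean/std is broadcast against a 1-channel tensor. A rough sketch of the three options (assuming torchvision; Grayscale is applied to the PIL image before ToTensor):

    import torch
    import torch.nn as nn
    from torchvision import transforms

    # Option 1: repeat the channel dimension manually
    x = torch.randn(1, 224, 224)   # a grayscale tensor [1, H, W]
    x3 = x.repeat(3, 1, 1)         # -> [3, 224, 224]

    # Option 2: let torchvision output 3 identical channels
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Option 3: "learn" the 1 -> 3 mapping with a conv layer in front of the model
    channel_mapper = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=1)
    x3_learned = channel_mapper(x.unsqueeze(0))  # [1, 3, 224, 224]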


@ptrblck

This is the line that throws the error:

  dataiter = iter(train_loader)
  images, labels = next(dataiter)  # the dataset's transforms run when the batch is fetched, raising the error

let me explain what I’m trying to do, and I would be happy if you have any suggestion.

I’m working with different color spaces (RGB, HSV, LAB, YIQ, …) and I have separated all of their channels to see which ones are the most effective/powerful (i.e. which give the highest accuracy in our case). This piece of information will help me in my project!
So, what I need is to train on only one channel, without duplication. Is that doable?

Sure! Make sure the image tensors in your Dataset have the shape [1, height, width] and set the number of input channels to one in the first conv layer of your model.
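
As a minimal sketch (assuming a torchvision ResNet18 as the model; the conv1 attribute is specific to ResNets):

    import torch
    import torch.nn as nn
    from torchvision import models

    model = models.resnet18()
    # replace the first conv layer so it accepts 1 input channel instead of 3
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    x = torch.randn(8, 1, 224, 224)  # a batch of single-channel images
    out = model(x)                   # runs without any channel duplication

Note that the replaced conv layer starts with random weights, so this fits training from scratch rather than fine-tuning a pretrained model.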

Yes, but that shape throws the error I mentioned above; it isn’t accepted. On top of that, the image is loaded with 3 channels even though it is actually a single-channel image, so I have to convert it first.
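
Something like this is what I have in mind for the conversion (just a sketch with torchvision transforms; the single-element mean/std values are placeholders):

    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # force a single channel
        transforms.Resize((224, 224)),
        transforms.ToTensor(),                        # -> [1, 224, 224]
        transforms.Normalize([0.5], [0.5]),           # one value per channel
    ])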

For example, this function doesn’t work:

    # assumes: import numpy as np, matplotlib.pyplot as plt
    def visulaizeTrainData(self, class_names):
        dataiter = iter(self.train_loader)
        images, labels = next(dataiter)
        images = images.numpy()
        fig = plt.figure(figsize=(25, 4))
        for idx in np.arange(images.shape[0]):
            ax = fig.add_subplot(2, images.shape[0] // 2, idx + 1, xticks=[], yticks=[])
            # [C, H, W] -> [H, W, C] for matplotlib
            inp = np.transpose(images[idx], (1, 2, 0))
            # undo the ImageNet normalization before plotting
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            inp = std * inp + mean
            inp = np.clip(inp, 0, 1)
            ax.imshow(inp)
            ax.set_title(class_names[labels[idx]])

It throws the same error.

And this function does as well:

    def trainNetwork(self, n_epochs, model_save_name):
        self.n_epochs = n_epochs
        train_on_gpu = torch.cuda.is_available()
        valid_loss_min = np.inf  # track change in validation loss
        for epoch in range(1, n_epochs + 1):
            train_loss = 0.0
            valid_loss = 0.0

            # Training
            self.model.train()
            for data, target in self.train_loader:
                if train_on_gpu:
                    data, target = data.cuda(), target.cuda()
                self.data = data
                self.optimizer.zero_grad()
                self.output = self.model(data)
                loss = self.criterion(self.output, target)
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item() * data.size(0)

            # Validation
            self.model.eval()
            with torch.no_grad():
                for data, target in self.valid_loader:
                    if train_on_gpu:
                        data, target = data.cuda(), target.cuda()
                    output = self.model(data)
                    loss = self.criterion(output, target)
                    valid_loss += loss.item() * data.size(0)

            train_loss = train_loss / len(self.train_loader.dataset)
            valid_loss = valid_loss / len(self.valid_loader.dataset)

            print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
                epoch, train_loss, valid_loss))

            # save the model whenever the validation loss improves
            if valid_loss <= valid_loss_min:
                print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                    valid_loss_min, valid_loss))
                path = f"/content/gdrive/My Drive/Colab Notebooks/New Trials/colorChannelsWeights/{model_save_name}"
                torch.save(self.model.state_dict(), path)
                valid_loss_min = valid_loss