Change 3 channel to 1 channel

Here is my code and my error picture

%cd
!mkdir -p /root/data
%cd /root/data
!pip install -q -U opencv-python
!pip install -q pandas
!pip install numpy==1.15.0
!pip install torch
!pip3 install torchvision
!pip install --upgrade torch torchvision

from google.colab import drive
drive.mount('/content/gdrive')

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as data
import torchvision
from torchvision import transforms
import cv2
from matplotlib import pyplot as plt
import torch.optim as optim

from google.colab import files
uploaded = files.upload()

items = os.listdir('/content')
print (items) 

for each_image in items:
  if each_image.endswith(".jpg"):
    print (each_image)
    full_path = "/content/" + each_image
    dest_path = "/root/data/" + each_image
    print (full_path)
    image = cv2.imread(full_path)
    im_gray = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
    im_blur=cv2.GaussianBlur(im_gray,(5,5),0)
    im,thre=cv2.threshold(im_blur,180,255,cv2.THRESH_BINARY_INV) 
    thre2=cv2.resize(thre,(28,28),interpolation=cv2.INTER_AREA)
    plt.subplot(2,3,6)
    plt.tight_layout()  
    plt.imshow(thre2)
    plt.xticks([])
    plt.yticks([])
    plt.show()  
    cv2.imwrite(dest_path,thre2)

!pip install -U -q PyDrive

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

from googleapiclient.http import MediaFileUpload
from googleapiclient.discovery import build
drive_service = build('drive','v3')

def save_file_to_drive(name,path):
  file_metadata={
      'name':name,
      'mimeType': 'application/octet-stream'
  }
  media= MediaFileUpload(path,mimetype='application/octet-stream',resumable=True)
  created=drive_service.files().create(body=file_metadata,media_body=media,fields='id').execute()
  return created


import zipfile
!zip -r /root/image.zip  /root/data

 save_file_to_drive('image.zip','/root/image.zip')


n_epochs=3
batch_size_train=32
batch_size_test=89
learning_rate=0.01
momentum=0.5
log_interval=10

random_seed=1
torch.backends.cudnn.enabled=False
torch.manual_seed(random_seed)

data_transforms = transforms.Compose([    
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225] )])

%cd /content/gdrive/'My Drive'/root/data

train_datasets = torchvision.datasets.ImageFolder(root= "/content/gdrive/My Drive/root/", transform=data_transforms)

import torch.utils.data as data
train_data_loader = data.DataLoader(train_datasets, batch_size=batch_size_train, shuffle=True,  num_workers=4)
test_data = torchvision.datasets.ImageFolder(root="/content/gdrive/My Drive/root/", transform=data_transforms)
test_data_loader = data.DataLoader(test_data, batch_size=batch_size_test, shuffle=True, num_workers=4)

examples=enumerate(test_data_loader)
batch_idx, (example_data, example_targets) = next(examples)

import matplotlib.pyplot as plt

fig = plt.figure()
for i in range(12):
  plt.subplot(3,4,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[i]))
  plt.xticks([])
  plt.yticks([])
fig

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_data_loader.dataset) for i in range(n_epochs + 1)]

def train(n_epochs):
  network.train()
  for batch_idx, (data, target) in enumerate(train_data_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        n_epochs, batch_idx * len(data), len(train_data_loader.dataset),
        100. * batch_idx / len(train_data_loader), loss.item()))
      train_losses.append(loss.item())
      train_counter.append((batch_idx*64) + ((n_epochs-1)*len(train_data_loader.dataset)))
      torch.save(network.state_dict(), 'model.pth')
      torch.save(optimizer.state_dict(), 'optimizer.pth')

def test():
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_data_loader:
      output = network(data)
      test_loss += F.nll_loss(output, target, size_average=False).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_data_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_data_loader.dataset),
    100. * correct / len(test_data_loader.dataset)))

test()
for n_epochs in range(1, n_epochs + 1):
  train(n_epochs)
  test()

You could set in_channels=3 for self.conv1 or alternatively convert your images to grayscale using:

data_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225] )])

PS: I’ve formatted your code for better readability. You can add code snippets using three backticks ``` :wink:

2 Likes