Binary Classification of MNIST with PyTorch

I have a simple two-layer fully connected neural network, but I am not sure how to convert the input images to binary form in PyTorch. Thank you in advance.

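import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

batch_size = 64  # assumed value; the original snippet does not show batch_size

Data_tr = datasets.MNIST('../data', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
                             transforms.Normalize((0.1307,), (0.3081,))]))

train_set = torch.utils.data.DataLoader(Data_tr, batch_size=batch_size, shuffle=True)
Data_ts = datasets.MNIST('../data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))]))

test_set = torch.utils.data.DataLoader(Data_ts, batch_size=batch_size, shuffle=True)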

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 10)
        #W1 = list(self.fc1.parameters())
        #print("w1:", len(W1[1]))

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten [batch, 1, 28, 28] -> [batch, 784] for fc1
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

net = Net()

By binary form, do you mean thresholding the images (only 0 and 1) rather than using them as grayscale images?
If that’s the case, I don’t think transforms has a function to threshold an image. You can write a custom dataset class that converts the image to binary form:

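class BinaryMNIST(Dataset):
    def __init__(self):
        self.images, self.labels = ...  # load these using datasets.MNIST
    def __len__(self):
        # DataLoader needs to know how many samples there are
        return len(self.labels)
    def __getitem__(self, idx):
        # clone so the stored images are not modified in place
        image, label = self.images[idx].clone(), self.labels[idx]
        # raw MNIST pixels are uint8 in 0-255; threshold at 127
        image[image < 127] = 0
        image[image >= 127] = 255
        # do the necessary transforms ...
        return image, label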

Thank you for your response.

Sorry, I am a beginner and I have some additional questions:

  1. What is Dataset in the first line? Should I use train_set instead, or should I define a new dataset?

  2. What do you mean by “using datasets.MNIST” in the __init__ function? Following my code, I wrote self.images, self.labels = Data_tr but got an error.

  1. ‘Dataset’ here is an abstract class defined in torch.utils.data, which should be subclassed when creating a custom dataset class:
binary_mnist = BinaryMNIST()
train_loader = torch.utils.data.DataLoader(binary_mnist, batch_size=batch_size, shuffle=True)
  2. You can run dir(Data_tr) to inspect its attributes. The two you need are train_data and train_labels (renamed data and targets in newer torchvision versions).
    Assign them accordingly inside the __init__ method:
self.images, self.labels = Data_tr.train_data, Data_tr.train_labels
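Putting the pieces of this answer together, a runnable sketch of the dataset might look like the following. It assumes a hypothetical constructor that takes the image and label tensors as arguments (unlike the no-argument BinaryMNIST above), returns float images of 0.0/1.0 with a channel dimension instead of 0/255, and uses random stand-in tensors so it runs without downloading MNIST.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class BinaryMNIST(Dataset):
    """Sketch: thresholds raw MNIST-style images on the fly.

    Assumes `images` is a uint8 tensor [N, 28, 28] with values in 0-255
    (like datasets.MNIST's .data / .train_data) and `labels` is [N].
    """

    def __init__(self, images: torch.Tensor, labels: torch.Tensor, threshold: int = 127):
        self.images = images
        self.labels = labels
        self.threshold = threshold

    def __len__(self) -> int:
        # DataLoader uses this to know how many samples exist.
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        # New 0.0/1.0 float tensor (self.images is untouched),
        # with a channel dimension added: [28, 28] -> [1, 28, 28].
        binary = (image >= self.threshold).float().unsqueeze(0)
        return binary, label


# Random stand-in data instead of the real MNIST download.
fake_images = torch.randint(0, 256, (8, 28, 28), dtype=torch.uint8)
fake_labels = torch.randint(0, 10, (8,))

binary_mnist = BinaryMNIST(fake_images, fake_labels)
loader = DataLoader(binary_mnist, batch_size=4, shuffle=True)
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape, batch_labels.shape)
```

With the real data you would construct it as BinaryMNIST(Data_tr.data, Data_tr.targets), or with train_data/train_labels on older torchvision. Because these raw tensors bypass Data_tr's transform pipeline, the Normalize step defined there is not applied.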

Check out this link for a dataset/dataloader tutorial:

Thank you very much. It is solved.


Hello @mailcorahul,

Do you think normalizing would make sense if I binarize the MNIST set?

Thanks in advance!
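If you do normalize binarized images, the (0.1307,), (0.3081,) values used earlier no longer describe the data, since they are the mean and standard deviation of the grayscale pixels. A hedged sketch for computing the statistics of the binarized pixels instead, assuming raw uint8 images (such as Data_tr.data) and the same 127 threshold; binarized_stats is a hypothetical helper, and whether normalizing helps at all is something to check empirically.

```python
import torch


def binarized_stats(images: torch.Tensor, threshold: int = 127):
    """Hypothetical helper: mean and std of pixels after thresholding.

    `images` is assumed to be a uint8 tensor with values in 0-255.
    """
    binary = (images >= threshold).float()
    # torch.std uses the unbiased (n - 1) estimator by default.
    return binary.mean().item(), binary.std().item()


# Tiny stand-in: half the pixels are on, half are off.
fake = torch.tensor([[[0, 255], [255, 0]]], dtype=torch.uint8)
mean, std = binarized_stats(fake)
print(mean, std)
```

On the real training images the resulting values could then be passed as transforms.Normalize((mean,), (std,)) after a binarizing step.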