How to Binarize datasets in PyTorch?

Hey, I am trying to binarize a dataset in PyTorch, specifically my test dataset. Here is the code:

from google.colab import drive
drive.mount('/content/drive')
data = "/content/drive/My Drive/AMD_new"


import torch
import helper
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models


# Changing the transforms of the data
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                     transforms.RandomResizedCrop(224),
                                     # transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

transform_test = transforms.Compose([transforms.RandomHorizontalFlip(),
                                     transforms.RandomResizedCrop(224),
                                     # transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# choose the training and test datasets
train_data = datasets.ImageFolder(data+"/train", transform=transform_train)
test_data = datasets.ImageFolder(data+"/val", transform = transform_test)

dataloader_train = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
dataloader_test = torch.utils.data.DataLoader(test_data, batch_size=32, num_workers=2)

But when I run these lines at the end, I get an error:

# Binarize the output
dataloader_test = label_binarize(dataloader_test, classes=[0, 1, 2, 3])
nb_classes = dataloader_test.shape[1]

Any lead on where this error comes from? As far as I know, it is not possible to use scikit-learn's ROC utilities without binarizing the labels first. The error is:

TypeError: Singleton array array(<torch.utils.data.dataloader.DataLoader object at 0x7fc3048321d0>,
dtype=object) cannot be considered a valid collection.
It means my dataloader_test is being treated as a single object, right? But how? It has four classes and each class has some 10 photos. Thank you for your help.

Not sure what you mean; maybe you should try target_transform or write your own Dataset.
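
For example, something like this (untested sketch; it assumes your four classes map to indices 0-3 and reuses your transform_test):

import torch
from torchvision import datasets

# target_transform receives the integer class index that ImageFolder assigns
# to each image and can turn it into a one-hot (binarized) vector
def to_one_hot(label, num_classes=4):
    target = torch.zeros(num_classes)
    target[label] = 1.0
    return target

test_data = datasets.ImageFolder(data + "/val",
                                 transform=transform_test,
                                 target_transform=to_one_hot)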

Sorry for the late reply. I am trying to get a ROC curve using scikit-learn. I have fpr and tpr too, but to plot the ROC I need to binarize the labels of dataloader_test. How can I do that? When I try this, I get the following:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-e1dbe1e9866d> in <module>()
     58 
     59 # Binarize the output
---> 60 dataloader_test = label_binarize(dataloader_test, classes=[0, 1, 2, 3])
     61 nb_classes = dataloader_test.shape[1]

/usr/local/lib/python3.6/dist-packages/sklearn/preprocessing/label.py in label_binarize(y, classes, neg_label, pos_label, sparse_output)
    579         # XXX Workaround that will be removed when list of list format is
    580         # dropped
--> 581         y = check_array(y, accept_sparse='csr', ensure_2d=False, dtype=None)
    582     else:
    583         if _num_samples(y) == 0:

/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)
    575     shape_repr = _shape_repr(array.shape)
    576     if ensure_min_samples > 0:
--> 577         n_samples = _num_samples(array)
    578         if n_samples < ensure_min_samples:
    579             raise ValueError("Found array with %d sample(s) (shape=%s) while a"

/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in _num_samples(x)
    140         if len(x.shape) == 0:
    141             raise TypeError("Singleton array %r cannot be considered"
--> 142                             " a valid collection." % x)
    143         # Check that shape is returning an integer or default to len
    144         # Dask dataframes may not return numeric shape[0] value

TypeError: Singleton array array(<torch.utils.data.dataloader.DataLoader object at 0x7fc3048321d0>,
      dtype=object) cannot be considered a valid collection.

Do you know what needs to be done? I am following this tutorial - https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
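
For reference, the part of the tutorial I am trying to reproduce looks roughly like this (y_test is a plain array of true class indices and y_score is the matrix of predicted scores; I only changed the classes list to my four classes):

from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

# binarize the true labels, then compute one ROC curve per class
y_test = label_binarize(y_test, classes=[0, 1, 2, 3])
n_classes = y_test.shape[1]

fpr, tpr, roc_auc = {}, {}, {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])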

You need to first get all the labels out, e.g.:

# N and K are placeholders: the number of entries you will store and the
# size of each stored entry (each `label` from the loader must fit in one row)
all_label = torch.zeros(N, K)
for i, (data, label) in enumerate(dataloader):
    all_label[i] = label
label_binarize(all_label)

I'd advise you to do it in the Dataset; it's a very easy operation, no need to make it complicated.
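
For example, ImageFolder already keeps every sample's class index in its imgs list of (path, class_index) pairs (newer torchvision versions also expose a targets attribute), so a simple option is something like this (untested sketch):

from sklearn.preprocessing import label_binarize

# test_data.imgs is the (image_path, class_index) list that ImageFolder built
y_true = [label for _, label in test_data.imgs]
y_true_bin = label_binarize(y_true, classes=[0, 1, 2, 3])
nb_classes = y_true_bin.shape[1]   # 4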

So, as I am not very experienced: do I have to initialize N and K with some arbitrary integers? And I have already written this kind of loop later in my code, like this:

for epoch in range(epochs):
  
  running_loss = 0
  model.train()
  for images, labels in dataloader_train:
    
    #steps += 1
    images, labels = images.to(device), labels.to(device)
    
    optimizer.zero_grad()
    
    output = model.forward(images)
    p = torch.nn.functional.softmax(output, dim=1)
    prediction = torch.argmax(p, dim=1)
    #loss = torch.nn.functional.nll_loss(torch.log(p), y)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()

So here in the first cell, where I have written my dataloader code, I have to write something like this:

all_label = torch.zeros(N, K)
for i, (image, labels) in enumerate(dataloader_test):
    all_label[i] = labels
label_binarize(all_label)
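
Or, to avoid having to guess N and K, maybe something like this (just my attempt, not tested):

import torch
from sklearn.preprocessing import label_binarize

# collect every label from the test loader into one flat tensor
all_labels = []
for images, labels in dataloader_test:   # labels is a 1-D batch of class indices
    all_labels.append(labels)
all_labels = torch.cat(all_labels)       # shape: (number of test images,)

# binarize the true labels for the ROC computation
y_test_bin = label_binarize(all_labels.numpy(), classes=[0, 1, 2, 3])
nb_classes = y_test_bin.shape[1]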

Am I right? Thanks.