Error in Dataloader: AttributeError: 'int' object has no attribute 'numel'

Hi, I am new to ML (and here), and this error has been haunting me for days. :skull_and_crossbones: :skull_and_crossbones: :skull_and_crossbones:
I have checked (almost) all the solutions online but none solves my problem.
Any help will be very much appreciated.

I am using a CNN to classify images into 11 classes.
The challenge is that only a portion of the training set is labelled (about 3000 images); the rest (about 6000) is unlabelled.
I am trying to do the labelling before each training epoch: if the predicted probability for an unlabelled image is higher than a certain threshold (e.g. 0.7), I add the image and its corresponding label to two lists.

def get_pseudo_labels(dataset, model, threshold=0.7):
    ...
    samples = []
    pseudolabels = []
    for batch in dataloader:            # iterate over the unlabelled set
        img, _ = batch                  # size [128 (batch size), 3, 128, 128]
        with torch.no_grad():
            logits = model(img.to(device))
        probs = softmax(logits)         # size [128, 11]
        max_probs, max_indices = torch.max(probs, dim=1)
        for i, (max_prob, max_idx) in enumerate(zip(max_probs, max_indices)):
            if max_prob > threshold:
                samples.append(img[i].cpu())        # img[i]: tensor of size (3, 128, 128)
                pseudolabels.append(max_idx.cpu())  # one of the 11 classes, a 0-dim tensor
                # I have checked the two lists' lengths; they are the same

    if len(samples) > 1:
        dataset = MyDataset(samples, pseudolabels)  # the pseudo set
    else:
        dataset = None
    return dataset

Then I send the two lists to instantiate a MyDataset object like this:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, X, y=None):
        # Stack the lists of tensors into single tensors
        self.data = torch.stack(X)   # size: torch.Size([98, 3, 128, 128]); (3, 128, 128) is one image
        self.label = torch.stack(y)  # size: torch.Size([98]); 98 is however many samples had prob > 0.7
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]
    def __len__(self):
        return len(self.data)

The code for training:

for epoch in range(n_epochs):

    if do_semi:  # do semi-supervised labelling, i.e. feed the pseudo-labels back into the training set
        pseudo_set = get_pseudo_labels(unlabeled_set, model)
        if pseudo_set is not None:
            concat_dataset = ConcatDataset([train_set, pseudo_set])
            train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
        else:
            train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    # ---------- Training ----------
    model.train()
    # Iterate over the training set by batches.
    for idx, batch in enumerate(train_loader):
        imgs, labels = batch
        # when the batch is fine, it yields:
        #   imgs shape:   torch.Size([128, 3, 128, 128]) <class 'torch.Tensor'>
        #   labels shape: torch.Size([128]) <class 'torch.Tensor'>

The error occurs here, at the beginning of the inner for loop.

AttributeError: Caught AttributeError in DataLoader worker process 1.
...
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 52, in default_collate
    numel = sum([x.numel() for x in batch])
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 52, in <listcomp>
    numel = sum([x.numel() for x in batch])
AttributeError: 'int' object has no attribute 'numel'

The error looks like I have passed something that is not a tensor into the dataset. I have checked almost everywhere; the pseudo set itself is fine, and the error occurs while the DataLoader is producing a batch (a quick way to check what the dataset actually yields is sketched after this post). The weird part is that the error does not appear every time a pseudo set is added into training: it comes in the later epochs, but sooner or later it definitely comes. Please, could anyone tell me where it could go wrong? Or recommend a better way to wrap up the pseudo set?
Again, any help is very VERY MUCH appreciated. :sob: :sob: :sob: :sob: :sob: :sob:
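For anyone debugging something similar, a quick sanity check is to scan the types that the concatenated dataset actually yields before handing it to the DataLoader. This is only a sketch, assuming the concat_dataset from the training code above is in scope:

from collections import Counter

# Count the Python types of the label part of each sample.
# A mixed result such as Counter({'int': 3000, 'Tensor': 98}) means the
# DataLoader can receive batches whose label fields have inconsistent types.
label_types = Counter(type(concat_dataset[i][1]).__name__ for i in range(len(concat_dataset)))
print(label_types)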

Update:
Don’t know why, but I kind of solved the problem…?
Instead of sending two lists into the dataset, I send two numpy arrays and alter the dataset class as follows:

import numpy as np
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, X, y=None):
        self.data = torch.from_numpy(X).float()
        if y is not None:
            y = y.astype(np.int64)  # np.int is deprecated/removed in recent NumPy versions
            self.label = torch.LongTensor(y)
        else:
            self.label = None

    def __getitem__(self, idx):
        if self.label is not None:
            return self.data[idx], self.label[idx]  # -> one (image, label) sample
        else:
            return self.data[idx]

    def __len__(self):
        return len(self.data)
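The post does not show the list-to-array conversion itself; one way it might look (a hypothetical sketch, using the samples and pseudolabels lists from get_pseudo_labels above):

import numpy as np

# Hypothetical conversion of the two lists into the arrays MyDataset expects
X = np.stack([s.numpy() for s in samples])     # shape (N, 3, 128, 128)
y = np.array([int(l) for l in pseudolabels])   # shape (N,)
pseudo_set = MyDataset(X, y)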

And then I define my own simple collate function for the DataLoader (its job is assembling samples into batches):

def my_collate(batch):
    """Custom collate_fn, because default_collate throws errors like crazy."""
    # each item in batch is a tuple of (img, label)
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    data = torch.stack(data)
    target = torch.LongTensor(target)
    return [data, target]

then use mine instead of default_collate:

train_loader = DataLoader(concat_dataset, batch_size=batch_size, collate_fn=my_collate, shuffle=True, num_workers=2, pin_memory=True)

Then it works; at least for the subsequent training I don’t get the error anymore. I still don’t know what causes the original error, but I hope people with the same problem find this useful.

I think it’s because of the difference between the outputs of the __getitem__ methods of the two concatenated datasets (train_set and pseudo_set). __getitem__ of train_set yields (image, int label), while __getitem__ of pseudo_set yields (image, tensor label). After the two datasets are concatenated, indexing the final dataset yields __getitem__ outputs whose structures don’t agree: (tensor, int) differs from (tensor, tensor), so the DataLoader’s default collation fails. However, if you’re lucky enough to draw a batch whose outputs all have identical structure, it will work for a while, which is why the error is intermittent. The new collate function you define applies LongTensor to all targets, which cancels out the difference between the two kinds of outputs, I guess.

import torch
a = [1,torch.tensor(2)]
print(torch.LongTensor(a))

And this will yield tensor([1, 2]).
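To make the failure mode concrete, here is a minimal sketch of what happens when the default collation gets a batch with mixed label types (default_collate is imported here from its private location to match the traceback; newer PyTorch exposes it as torch.utils.data.default_collate):

import torch
from torch.utils.data._utils.collate import default_collate

# Two samples whose label fields have different types: a 0-dim tensor and a plain int
batch = [(torch.zeros(3, 128, 128), torch.tensor(1)),
         (torch.zeros(3, 128, 128), 2)]

# default_collate transposes the batch and collates each field separately.
# The label field [tensor(1), 2] starts with a tensor, so the tensor branch is
# taken: inside a worker process it calls x.numel() on the plain int and raises
# "'int' object has no attribute 'numel'"; in the main process torch.stack
# raises a TypeError instead. Either way, mixed label types break collation.
default_collate(batch)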


Exactly. I figured it out after posting the first reply (the code of the original dataset was not straightforward to me; it is a DatasetFolder class, and I couldn’t really look into it to see its data type directly). Thank you for pointing it out!! :sob: :orange_heart:

You can read the source code of the class you use to load your data, such as DatasetFolder.
DatasetFolder keeps its labels as a plain list of ints, so you need to use the same type for your own labels.
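Given that, the simplest fix on the pseudo-label side (an alternative to the custom collate function, not from the original post) would be to store plain ints so both datasets yield the same structure:

# Inside get_pseudo_labels: store a plain int instead of a 0-dim tensor,
# so pseudo_set labels match the int labels that DatasetFolder yields.
pseudolabels.append(max_idx.item())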