ImbalancedDatasetSampler DataLoader Sampler

I wrote this sampler class so that I can pass a sampler into my DataLoader. However, I keep getting the AttributeError posted below the code, and I'm not sure why this is happening.

import torch


class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for an imbalanced dataset.

    Arguments:
        dataset (Dataset): the dataset to sample from
        indices (list, optional): a list of indices
        num_samples (int, optional): number of samples to draw
        callback_get_label (callable, optional): a callback-like function that takes two arguments, dataset and index
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
                
        # if indices is not provided, 
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) \
            if indices is None else indices

        # define custom callback
        self.callback_get_label = callback_get_label

        # if num_samples is not provided, 
        # draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) \
            if num_samples is None else num_samples
            
        # distribution of classes in the dataset 
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            if label in label_to_count:
                label_to_count[label] += 1
            else:
                label_to_count[label] = 1
                
        # weight for each sample
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

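    def _get_label(self, dataset, idx):
        # use the custom callback if one was provided
        if self.callback_get_label is not None:
            return self.callback_get_label(dataset, idx)
        return dataset.train_labels[idx].item()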
                
    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples
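The weighting in __init__ gives each sample a weight of 1 / (number of samples in its class), so every class carries the same total weight and torch.multinomial draws the classes roughly equally often. A standalone sketch of just that logic, using made-up labels rather than the real dataset:

```python
from collections import Counter

import torch

# Made-up labels for an imbalanced toy dataset: class 0 dominates.
labels = [0, 0, 0, 0, 0, 0, 1, 1, 2]

# Count how often each class occurs (same idea as label_to_count above).
label_to_count = Counter(labels)

# Each sample is weighted by 1 / (size of its class), so every class
# ends up with the same total weight (1.0 here).
weights = torch.tensor([1.0 / label_to_count[y] for y in labels], dtype=torch.double)

# Draw indices with replacement; classes should now appear roughly equally often.
torch.manual_seed(0)
drawn = torch.multinomial(weights, num_samples=3000, replacement=True)
drawn_counts = Counter(labels[i] for i in drawn.tolist())
print(drawn_counts)
```

With 3000 draws, each of the three classes should come out near 1000 even though class 0 makes up two thirds of the labels. The same per-sample weight vector could also be handed to the built-in torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True).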

Error:

     21 
     22
---> 23         train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, sampler=ImbalancedDatasetSampler(train_data))
     24 

<ipython-input-13-87caa7d6b353> in __init__(self, dataset, indices, num_samples, callback_get_label)
     26         label_to_count = {}
     27         for idx in self.indices:
---> 28             label = self._get_label(dataset, idx)
     29             if label in label_to_count:
     30                 label_to_count[label] += 1

<ipython-input-13-87caa7d6b353> in _get_label(self, dataset, idx)
     38 
     39     def _get_label(self, dataset, idx):
---> 40         return dataset.train_labels[idx].item()
     41 
     42     def __iter__(self):

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataset.py in __getattr__(self, attribute_name)
     81             return function
     82         else:
---> 83             raise AttributeError
     84 
     85     @classmethod

What is dataset here? Possible reasons are:

  • dataset does not have an attribute train_labels
  • dataset.train_labels[idx] has no attribute item (for example, if it is a plain Python int rather than a tensor)
  • Is there a library named dataset involved? (It is not clear from the code.)

Please check the above.
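One way to narrow down the first two possibilities is to inspect the object you actually pass to the sampler. A minimal sketch; inspect_label_source is a hypothetical helper, and the TensorDataset built from random tensors only stands in for train_data:

```python
import torch
from torch.utils.data import Dataset, TensorDataset


def inspect_label_source(dataset):
    """Report which label sources a dataset object actually offers."""
    first = dataset[0]
    return {
        "type": type(dataset).__name__,
        "is_torch_dataset": isinstance(dataset, Dataset),
        "has_train_labels": hasattr(dataset, "train_labels"),
        "has_targets": hasattr(dataset, "targets"),
        # Many datasets return (features, label) tuples from __getitem__.
        "item_is_tuple": isinstance(first, tuple),
        "item_length": len(first) if isinstance(first, tuple) else None,
    }


# Toy stand-in for train_data: 4 float32 feature vectors with integer labels.
features = torch.randn(4, 3, dtype=torch.float32)
labels = torch.tensor([0, 1, 0, 0])
train_data = TensorDataset(features, labels)

print(inspect_label_source(train_data))
```

If has_train_labels is False, the dataset.train_labels lookup in _get_label is what raises. The last frame of the traceback above is Dataset.__getattr__ in torch/utils/data/dataset.py, which is consistent with that first case.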

Indexing the dataset returns a tuple (X, Y), where X is a float32 feature vector and Y is the label for that vector.

I guess there is no attribute train_labels in your train_data.
Does your train_data inherit from torch.utils.data.Dataset?
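Subclassing torch.utils.data.Dataset does not by itself give an object a train_labels attribute. torchvision's MNIST dataset exposed one (since deprecated in favor of targets), which is presumably what the original _get_label was written for. If you control the dataset class, one option is to keep the labels in an attribute as well as returning them from __getitem__. A sketch under that assumption; FeatureDataset is hypothetical, and the attribute name targets is chosen only to mirror torchvision's convention:

```python
import torch
from torch.utils.data import Dataset


class FeatureDataset(Dataset):
    """Hypothetical dataset returning (X, Y) pairs, with labels also exposed as .targets."""

    def __init__(self, features, labels):
        self.features = torch.as_tensor(features, dtype=torch.float32)
        # Keeping labels in an attribute lets a sampler read them
        # without calling __getitem__ (and loading X) for every index.
        self.targets = torch.as_tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]


ds = FeatureDataset([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [1, 0, 1])
x, y = ds[2]
print(isinstance(ds, Dataset), hasattr(ds, "train_labels"), ds.targets.tolist(), int(y))
```

A sampler could then read ds.targets[idx] directly, which avoids materializing every X just to count labels.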

Oh, now I understand. I switched it to dataset[idx][1] and it worked. Thank you!
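For reference, a sketch of the sampler with the fix from this thread applied: the label comes from dataset[idx][1] and is converted with int(), assuming labels are integer class ids (Python ints or 0-d tensors). callback_get_label is left out for brevity, labels are read once instead of twice, and a toy TensorDataset stands in for the real train_data. Note also that DataLoader raises a ValueError when shuffle=True is combined with a sampler, so the call shown in the traceback needs shuffle removed.

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class ImbalancedDatasetSampler(Sampler):
    """Weighted sampler that reads the label from dataset[idx][1]."""

    def __init__(self, dataset, indices=None, num_samples=None):
        self.indices = list(range(len(dataset))) if indices is None else indices
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # int() turns a 0-d label tensor into a plain Python int; tensors hash
        # by identity, so they would not work as dictionary keys here.
        labels = [int(dataset[idx][1]) for idx in self.indices]
        label_to_count = Counter(labels)

        # Inverse class frequency: every class gets the same total weight.
        self.weights = torch.tensor(
            [1.0 / label_to_count[y] for y in labels], dtype=torch.double
        )

    def __iter__(self):
        picks = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[i] for i in picks.tolist())

    def __len__(self):
        return self.num_samples


# Toy imbalanced data standing in for train_data: 8 samples of class 0, 2 of class 1.
features = torch.randn(10, 4, dtype=torch.float32)
targets = torch.tensor([0] * 8 + [1] * 2)
train_data = TensorDataset(features, targets)

# shuffle must stay False (the default) when a sampler is given.
train_loader = DataLoader(
    train_data, batch_size=5, sampler=ImbalancedDatasetSampler(train_data)
)
for X, Y in train_loader:
    print(X.shape, Y.tolist())
```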