I have this sampler class in order to pass a sampler into my DataLoader. However, when I keep getting an AttributeError that I’ve posted below the code. Not sure why this is happening.
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
"""Samples elements randomly from a given list of indices for imbalanced dataset
Arguments:
indices (list, optional): a list of indices
num_samples (int, optional): number of samples to draw
callback_get_label func: a callback-like function which takes two arguments - dataset and index
"""
def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
# if indices is not provided,
# all elements in the dataset will be considered
self.indices = list(range(len(dataset))) \
if indices is None else indices
# define custom callback
self.callback_get_label = callback_get_label
# if num_samples is not provided,
# draw `len(indices)` samples in each iteration
self.num_samples = len(self.indices) \
if num_samples is None else num_samples
# distribution of classes in the dataset
label_to_count = {}
for idx in self.indices:
label = self._get_label(dataset, idx)
if label in label_to_count:
label_to_count[label] += 1
else:
label_to_count[label] = 1
# weight for each sample
weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
for idx in self.indices]
self.weights = torch.DoubleTensor(weights)
def _get_label(self, dataset, idx):
return dataset.train_labels[idx].item()
def __iter__(self):
return (self.indices[i] for i in torch.multinomial(
self.weights, self.num_samples, replacement=True))
def __len__(self):
return self.num_samples
Error:
21
22
---> 23 train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, sampler=ImbalancedDatasetSampler(train_data))
24
<ipython-input-13-87caa7d6b353> in __init__(self, dataset, indices, num_samples, callback_get_label)
26 label_to_count = {}
27 for idx in self.indices:
---> 28 label = self._get_label(dataset, idx)
29 if label in label_to_count:
30 label_to_count[label] += 1
<ipython-input-13-87caa7d6b353> in _get_label(self, dataset, idx)
38
39 def _get_label(self, dataset, idx):
---> 40 return dataset.train_labels[idx].item()
41
42 def __iter__(self):
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataset.py in __getattr__(self, attribute_name)
81 return function
82 else:
---> 83 raise AttributeError
84
85 @classmethod