WeightedRandomSampler


(Johan Hansson) #1

Hi!
I have some problem with WeightedRandomSampler, been going over old threds but can’t get my head around it. I am using the following code.

from torch.utils.data import Dataset

class CustomDataset(Dataset):
“”“Face Landmarks dataset.”""

def __init__(self, dataFrame, transform):
    """
    Args:
        csv_file (string): Path to the csv file with annotations.
        root_dir (string): Directory with all the images.
        transform (callable, optional): Optional transform to be applied
            on a sample.
    """
    self.df=df
    self.transform = transform

def __len__(self):
    return len(self.df)

def __getitem__(self, idx):  
    img = Image.open(self.df.iloc[idx]['Path'])
    img = img.convert('RGB')

    name = df.iloc[idx]['Flower']
    sample = {'image': img, 'name': name,'dir':self.df.iloc[idx]['FlowerDir']}
    if self.transform:
        sample['image'] = self.transform(sample['image'])
    return sample['image'], int(sample['dir'])-1

df[‘weight’]=float(0)
for idx in range(len(df)):
Flower = df.iloc[idx][‘Flower’]
weights=float(1 / df[‘Flower’].value_counts()[Flower])
df.at[idx, ‘weight’] = weights

df.head(n=5)
dfList = list(df[‘weight’])
dfList=torch.FloatTensor(dfList)

dataset_length = len(df)

number of subprocesses to use for data loading

num_workers = 0

how many samples per batch to load

batch_size=20

percentage of training set to use as validation

valid_size = 0.2

image_datasets = {
‘training’ : CustomDataset(dataFrame=df,transform=transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),transforms.RandomRotation(30),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])),
‘test’ : CustomDataset(dataFrame=df2,transform=transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),transforms.RandomRotation(30),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])]))}

#weights = 1 / torch.Tensor(df[‘Flower’].value_counts())
#weights = weights.double()

sampler2 = torch.utils.data.sampler.WeightedRandomSampler(dfList, len(df) )

train_loader = torch.utils.data.DataLoader(image_datasets[‘training’], batch_size=batch_size,sampler=sampler2,num_workers=num_workers)

valid_loader = torch.utils.data.DataLoader(image_datasets[‘test’], batch_size=batch_size,num_workers=num_workers)
print(‘done’)

dataiter = iter(train_loader)
images, labels = dataiter.next()

When I try to itterate over the data I have the problem belove


TypeError Traceback (most recent call last)
in ()
1 dataiter = iter(train_loader)
----> 2 images, labels = dataiter.next()
3 print(labels)
4 print(len(labels))

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in next(self)
312 if self.num_workers == 0: # same-process loading
313 indices = next(self.sample_iter) # may raise StopIteration
–> 314 batch = self.collate_fn([self.dataset[i] for i in indices])
315 if self.pin_memory:
316 batch = pin_memory_batch(batch)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in (.0)
312 if self.num_workers == 0: # same-process loading
313 indices = next(self.sample_iter) # may raise StopIteration
–> 314 batch = self.collate_fn([self.dataset[i] for i in indices])
315 if self.pin_memory:
316 batch = pin_memory_batch(batch)

in getitem(self, idx)
20
21 def getitem(self, idx):
—> 22 img = Image.open(self.df.iloc[idx][‘Path’])
23 img = img.convert(‘RGB’)
24

/usr/local/lib/python3.6/dist-packages/pandas/core/indexing.py in getitem(self, key)
1371
1372 maybe_callable = com._apply_if_callable(key, self.obj)
-> 1373 return self._getitem_axis(maybe_callable, axis=axis)
1374
1375 def _is_scalar_access(self, key):

/usr/local/lib/python3.6/dist-packages/pandas/core/indexing.py in _getitem_axis(self, key, axis)
1817 # a list of integers
1818 elif is_list_like_indexer(key):
-> 1819 return self._get_list_axis(key, axis=axis)
1820
1821 # a single integer

/usr/local/lib/python3.6/dist-packages/pandas/core/indexing.py in _get_list_axis(self, key, axis)
1792 axis = self.axis or 0
1793 try:
-> 1794 return self.obj._take(key, axis=axis, convert=False)
1795 except IndexError:
1796 # re-raise with different error message

/usr/local/lib/python3.6/dist-packages/pandas/core/generic.py in _take(self, indices, axis, convert, is_copy)
2148 new_data = self._data.take(indices,
2149 axis=self._get_block_manager_axis(axis),
-> 2150 verify=True)
2151 result = self._constructor(new_data).finalize(self)
2152

/usr/local/lib/python3.6/dist-packages/pandas/core/internals.py in take(self, indexer, axis, verify, convert)
4262 new_labels = self.axes[axis].take(indexer)
4263 return self.reindex_indexer(new_axis=new_labels, indexer=indexer,
-> 4264 axis=axis, allow_dups=True)
4265
4266 def merge(self, other, lsuffix=’’, rsuffix=’’):

/usr/local/lib/python3.6/dist-packages/pandas/core/internals.py in reindex_indexer(self, new_axis, indexer, axis, fill_value, allow_dups, copy)
4148 new_blocks = [blk.take_nd(indexer, axis=axis, fill_tuple=(
4149 fill_value if fill_value is not None else blk.fill_value,))
-> 4150 for blk in self.blocks]
4151
4152 new_axes = list(self.axes)

/usr/local/lib/python3.6/dist-packages/pandas/core/internals.py in (.0)
4148 new_blocks = [blk.take_nd(indexer, axis=axis, fill_tuple=(
4149 fill_value if fill_value is not None else blk.fill_value,))
-> 4150 for blk in self.blocks]
4151
4152 new_axes = list(self.axes)

/usr/local/lib/python3.6/dist-packages/pandas/core/internals.py in take_nd(self, indexer, axis, new_mgr_locs, fill_tuple)
1219 fill_value = fill_tuple[0]
1220 new_values = algos.take_nd(values, indexer, axis=axis,
-> 1221 allow_fill=True, fill_value=fill_value)
1222
1223 if new_mgr_locs is None:

/usr/local/lib/python3.6/dist-packages/pandas/core/algorithms.py in take_nd(arr, indexer, axis, out, fill_value, mask_info, allow_fill)
1368 if out is None:
1369 out_shape = list(arr.shape)
-> 1370 out_shape[axis] = len(indexer)
1371 out_shape = tuple(out_shape)
1372 if arr.flags.f_contiguous and axis == arr.ndim - 1:

TypeError: len() of unsized object

Sorry if a posted a long question but thought it was better to show the full code :slight_smile: Cheers!


(Johan Hansson) #2

The problem is probably in my customeDataset, solved by using ImageFolder instead.