I’m getting an error when I use a DataLoader based on split dataset through torch.utils.data.random_split()
.
I have tried creating a dataloader with the un-split dataset and everything works fine. So I assume it’s something do with the splitting.
I have a custom dataset defined as:
class AntsDataset(Dataset):
def __init__(self, root_dir, csv_file, transform=None):
"""
Args:
csv_file (string): Path to the csv_file with rotations
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.rotations = pd.read_csv(csv_file,header=None)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.rotations)
def __getitem__(self, idx):
#import ipdb; ipdb.set_trace()
img_name = os.path.join(self.root_dir,
self.rotations.iloc[idx, 0])
image = plt.imread(img_name,format='RGB')
rotation = self.rotations.iloc[idx, 1].astype('float')
if self.transform is not None:
image=self.transform(image)
return (image, rotation)
I then create a dataset, split and form a data_loader:
ants_dataset=AntsDataset(ants1_root_dir, ants1_rot_file,
transform=transforms.Compose([transforms.ToPILImage(),
transforms.Resize((120,120)),
transforms.RandomCrop(size=100, pad_if_needed=True),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.07),
transforms.ToTensor()]))
dataloader=torch.utils.data.DataLoader(ants_dataset,
batch_size=10, shuffle=True)
train_length=int(0.7* len(ants_dataset))
test_length=len(ants_dataset)-train_length
train_dataset,test_dataset=torch.utils.data.random_split(ants_dataset,(train_length,test_length))
dataloader_train=torch.utils.data.DataLoader(train_dataset,
batch_size=10, shuffle=True)
for batch_idx, (data,rotations) in enumerate(dataloader_train):
print(rotations)
When I try to loop over the dataloader I get the following error:
--------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-82-f629e71651de> in <module>()
1 dataloader_train=torch.utils.data.DataLoader(train_dataset,
2 batch_size=10, shuffle=True)
----> 3 for batch_idx, (data,rotations) in enumerate(dataloader_train):
4 print(rotations)
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
312 if self.num_workers == 0: # same-process loading
313 indices = next(self.sample_iter) # may raise StopIteration
--> 314 batch = self.collate_fn([self.dataset[i] for i in indices])
315 if self.pin_memory:
316 batch = pin_memory_batch(batch)
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py in <listcomp>(.0)
312 if self.num_workers == 0: # same-process loading
313 indices = next(self.sample_iter) # may raise StopIteration
--> 314 batch = self.collate_fn([self.dataset[i] for i in indices])
315 if self.pin_memory:
316 batch = pin_memory_batch(batch)
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataset.py in __getitem__(self, idx)
101
102 def __getitem__(self, idx):
--> 103 return self.dataset[self.indices[idx]]
104
105 def __len__(self):
<ipython-input-44-01f6586a3276> in __getitem__(self, idx)
20 #import ipdb; ipdb.set_trace()
21 img_name = os.path.join(self.root_dir,
---> 22 self.rotations.iloc[idx, 0])
23 image = plt.imread(img_name,format='RGB')
24 rotation = self.rotations.iloc[idx, 1].astype('float')
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/pandas/core/indexing.py in __getitem__(self, key)
1470 except (KeyError, IndexError):
1471 pass
-> 1472 return self._getitem_tuple(key)
1473 else:
1474 # we by definition only have the 0th axis
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/pandas/core/indexing.py in _getitem_tuple(self, tup)
2011 def _getitem_tuple(self, tup):
2012
-> 2013 self._has_valid_tuple(tup)
2014 try:
2015 return self._getitem_lowerdim(tup)
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/pandas/core/indexing.py in _has_valid_tuple(self, key)
220 raise IndexingError('Too many indexers')
221 try:
--> 222 self._validate_key(k, i)
223 except ValueError:
224 raise ValueError("Location based indexing can only have "
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/pandas/core/indexing.py in _validate_key(self, key, axis)
1965 l = len(self.obj._get_axis(axis))
1966
-> 1967 if len(arr) and (arr.max() >= l or arr.min() < -l):
1968 raise IndexError("positional indexers are out-of-bounds")
1969 else:
TypeError: len() of unsized object