Dear All,
This relates to one of my earlier posts (Custom data loader and label encoding with CIFAR-10 - #3 by QuantScientist), but it deserves a new thread.
When I iterate over the dataset during training, like so:
for i, (inputs, labels) in enumerate(train_loader):
    print(type(inputs))
    print(type(labels))
    print("Label:" + str(labels))
The labels come back as a tuple of 4 items instead of a single item (for instance, "dog"). I understand that the problem is with my data loader; however, I cannot figure out which line of code is the culprit.
The full code is here:
https://github.com/QuantScientist/Deep-Learning-Boot-Camp/blob/master/day%2002%20PyTORCH%20and%20PyCUDA/PyTorch/21-PyTorch-CIFAR-10-Custom-data-loader-from-scratch.ipynb
And this is the exception:
Many thanks for any help!
chenyuntc (Yun Chen), August 28, 2017, 3:14am, #2
Turn your labels into numbers, e.g. frog -> 0, truck -> 1, and then turn them into a tensor.
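A minimal sketch of that encoding step, assuming the labels arrive as plain Python strings. The sample labels and the names `classes` and `class_to_idx` are illustrative, not taken from the notebook:

```python
# Hypothetical string labels, as a custom loader might read them from disk.
labels = ["frog", "truck", "truck", "frog"]

# Sort the distinct class names so the mapping is the same on every run.
classes = sorted(set(labels))
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Replace every string label with its integer index.
encoded = [class_to_idx[name] for name in labels]

print(class_to_idx)  # {'frog': 0, 'truck': 1}
print(encoded)       # [0, 1, 1, 0]
```

The resulting integers can then be wrapped in a tensor, for example with torch.LongTensor(encoded).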
Thanks Chen, I found this yesterday, used defaultdict(LabelEncoder)
and updated the notebook: https://github.com/QuantScientist/Deep-Learning-Boot-Camp/blob/master/day%2002%20PyTORCH%20and%20PyCUDA/PyTorch/21-PyTorch-CIFAR-10-Custom-data-loader-from-scratch.ipynb
I was surprised, since PyTorch can work with class labels that are not encoded, but when you write a custom dataset it actually forces you to encode them yourself.
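One likely reason unencoded labels appear to work with folder-based datasets such as ImageFolder is that the folder names are mapped to integer indices behind the scenes. A rough sketch of that idea, not torchvision's actual code; the helper name `folder_names_to_indices` and the temporary directory layout are made up for illustration:

```python
import os
import tempfile


def folder_names_to_indices(root):
    """Map each subdirectory name under root to an integer index (hypothetical helper)."""
    classes = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )
    return {name: idx for idx, name in enumerate(classes)}


# Build a throwaway root/dog, root/cat layout to try the helper on.
with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat"):
        os.makedirs(os.path.join(root, name))
    class_to_idx = folder_names_to_indices(root)

print(class_to_idx)  # {'cat': 0, 'dog': 1}
```

A custom dataset that reads labels from somewhere other than folder names has no such step unless you write it yourself.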
chenyuntc (Yun Chen), August 28, 2017, 8:05am, #4
Actually, Dataset and DataLoader are not that complicated. I would strongly advise you to read the ImageFolder source:
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
This file has been truncated.
and the DataLoader source around default_collate:
        if pin_memory:
            batch = pin_memory_batch(batch)
    except Exception:
        out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
    else:
        out_queue.put((idx, batch))


numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def default_collate(batch):
This file has been truncated.
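The tuple of 4 labels from the original question is consistent with how a batch gets collated: the samples are regrouped field by field, and string labels cannot be stacked into a tensor, so they stay a sequence with one entry per sample. A much simplified sketch of that regrouping, assuming a batch size of 4; `simple_collate` is a hypothetical stand-in, not the real default_collate:

```python
def simple_collate(batch):
    """Simplified stand-in for a collate function: regroup samples field by field."""
    # batch is a list of (input, label) samples; zip(*batch) transposes it
    # into one tuple holding all inputs and one tuple holding all labels.
    inputs, labels = zip(*batch)
    return inputs, labels


# Four hypothetical samples, i.e. a batch size of 4 (an assumption about the notebook).
batch = [("img0", "dog"), ("img1", "cat"), ("img2", "frog"), ("img3", "truck")]
inputs, labels = simple_collate(batch)

print(type(labels))  # <class 'tuple'>
print(labels)        # ('dog', 'cat', 'frog', 'truck')
```

Once the dataset returns integer labels instead of strings, the labels of a batch can be combined into a single tensor.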