Hello,
I am trying to train a model on the unlabeled split of the STL10 dataset. When I run the following code, I get an error:
```python
import torch
from torchvision import datasets, transforms
stl_transforms = transforms.Compose([transforms.ToTensor(),
                                     transforms.ToPILImage(),
                                     transforms.ToTensor()])
unsupervised_data = datasets.STL10(root='../data', split='unlabeled',
                                   transform=stl_transforms, download=True)
unsupervised_train_loader = torch.utils.data.DataLoader(unsupervised_data,
                                                        batch_size=args.batch_size, shuffle=True)
# The error is raised as soon as the loader yields its first batch
for batch_idx, (data, _) in enumerate(unsupervised_train_loader):
    ...  # training step
```
The loop fails with:
```
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'NoneType'>
```
Interestingly, if I set `split` to `'train'` or `'train+unlabeled'`, everything works fine.
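To narrow it down, here is a minimal sketch that reproduces the same kind of failure without downloading STL10. It assumes (based only on the error message) that the `'unlabeled'` split returns `None` as the target for every sample. `UnlabeledLikeDataset`, `collate_images_only` and `first_batch_fails` are hypothetical helpers written for this sketch, not part of torchvision. The default collate function cannot batch `None`, whereas a custom `collate_fn` that keeps only the images does not hit the error:

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class UnlabeledLikeDataset(Dataset):
    """Hypothetical stand-in for a split that returns (image, None)."""

    def __init__(self, n=8):
        # Random 3x96x96 images, the same size as STL10 images
        self.images = torch.rand(n, 3, 96, 96)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], None  # no label available


def collate_images_only(batch):
    """Batch only the images and pass None through for the targets."""
    images = default_collate([img for img, _ in batch])
    return images, None


def first_batch_fails(loader):
    """Return True if fetching the first batch raises a TypeError."""
    try:
        next(iter(loader))
    except TypeError:
        return True
    return False


if __name__ == "__main__":
    dataset = UnlabeledLikeDataset()

    # Default collate tries to batch the None targets and raises TypeError
    print(first_batch_fails(DataLoader(dataset, batch_size=4, shuffle=True)))  # True

    # Custom collate_fn: the same loop as in my script now runs
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        collate_fn=collate_images_only)
    for batch_idx, (data, _) in enumerate(loader):
        print(batch_idx, data.shape)
```

If that assumption holds, it would also explain why `'train+unlabeled'` works: there the unlabeled images presumably get a numeric placeholder label instead of `None`.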
Any help would be appreciated!