Modifying torchvision.datasets.CIFAR10

I want to modify the train_data and train_labels of a torchvision.datasets.CIFAR10 object:

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False)

When I modify trainset.train_labels and trainset.train_data, the lengths of these lists change. However, when I call len() on the CIFAR10 object, it still returns the old train_data length. As a result, when I later pass this trainset to a DataLoader:

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

I get the following error:

samples = collate_fn([dataset[i] for i in batch_indices])
File "/usr/local/lib/python2.7/dist-packages/torchvision/datasets/cifar.py", line 83, in __getitem__
img, target = self.train_data[index], self.train_labels[index]
IndexError: index 36238 is out of bounds for axis 0 with size 27500

Do you know what I should do to avoid this error?

The master branch of torchvision should have this fixed as per #211. If you look at the __len__() method, it used to return a hard-coded 50,000, but now it returns the actual size of the training data. If you can't get master for some reason, you can do what I do and subclass it like so:

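import torchvision.datasets as dset  # assumed meaning of the dset alias

class CIFAR10(dset.CIFAR10):
    # Report the current size of the data instead of a hard-coded value.
    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)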
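
To see why a hard-coded length leads to the IndexError above, here is a minimal pure-Python sketch. The classes FixedLenDataset and ActualLenDataset and the helper sample_all are hypothetical stand-ins made up for illustration, not torchvision code, and sample_all only roughly imitates how a shuffling loader draws indices from range(len(dataset)).

```python
import random


class FixedLenDataset:
    """Toy stand-in for the old behaviour (hypothetical, not torchvision code)."""

    def __init__(self, n_items):
        self.train_data = list(range(n_items))

    def __len__(self):
        # Hard-coded: ignores any later change to train_data.
        return 50000

    def __getitem__(self, index):
        return self.train_data[index]


class ActualLenDataset(FixedLenDataset):
    """Same data, but the length follows whatever train_data currently holds."""

    def __len__(self):
        return len(self.train_data)


def sample_all(dataset, seed=0):
    """Visit every index in range(len(dataset)) in shuffled order."""
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)
    return [dataset[i] for i in indices]


# Shrink the data after construction, as in the question.
broken = FixedLenDataset(50000)
broken.train_data = broken.train_data[:27500]
try:
    sample_all(broken)
    failed = False
except IndexError:
    # An index >= 27500 was requested because len() still says 50000.
    failed = True

fixed = ActualLenDataset(50000)
fixed.train_data = fixed.train_data[:27500]
samples = sample_all(fixed)

print(failed, len(broken), len(fixed), len(samples))
# prints: True 50000 27500 27500
```

The override is the same idea as the subclass above: once __len__ reflects the current data, no index past the end is ever requested.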

I have the same problem even after upgrading to 0.2.0_1. But the subclass method works!

Thanks.