I would like to find the indexes of certain targets and create a new data and target for my CIFAR10 dataset. This is how I’ve done it but I get the error.
class _CIFAR_Split(data.Dataset):
def __init__(self, root, targets, split='train', transform=None, download=False):
self.root = os.path.expanduser(root)
self.transform = transform
self.split = split
if self.split == 'train':
# pdb.set_trace()
self._data = torchvision.datasets.CIFAR10(root=self.root, train=True, download=download, transform=None)
# inputs to full_like aren't accepted. First input must be a torch.tensor
t = torch.Tensor(self._data.targets)
idx = torch.full_like(t, 0, dtype=bool)
# idx is a tensor in boolean type
# targets = [0,1]
for target in targets:
idx = (idx | (self._data.targets == target))
self._data.targets = self._data.targets[idx]
self._data.data = self._data.data[idx]
I get the error in like self._data.targets = self._data.targets[idx]
.
I tried running in two environments but I get the same error.
First environment: Python 3.6.9, Pytorch 1.4.0, torchvision 0.5
Second environment: python 3.10, pytorch 1.11.0, torchvision 0.12
Is there a way to make this work? I think one possibility is to use np.where() but I would like to know if I can do any type conversions so this works. Thank you in advance