Error when indexing: TypeError: only integer tensors of a single element can be converted to an index

I would like to find the indices of certain targets and build new data and targets for my CIFAR10 dataset. This is how I've done it, but I get the error above.

import os

import torch
import torchvision
from torch.utils import data


class _CIFAR_Split(data.Dataset):

    def __init__(self, root, targets, split='train', transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.split = split

        if self.split == 'train':
            # pdb.set_trace()
            self._data = torchvision.datasets.CIFAR10(root=self.root, train=True, download=download, transform=None)
            # full_like doesn't accept a list; its first argument must be a torch.Tensor
            t = torch.Tensor(self._data.targets)
            # idx is a boolean mask with the same shape as t, initialised to all False
            idx = torch.full_like(t, 0, dtype=bool)
            # e.g. targets = [0, 1]
            for target in targets:
                idx = (idx | (self._data.targets == target))
            self._data.targets = self._data.targets[idx]  # the TypeError is raised here
            self._data.data = self._data.data[idx]

I get the error on the line self._data.targets = self._data.targets[idx].
I tried this in two environments and got the same error in both:
First environment: Python 3.6.9, PyTorch 1.4.0, torchvision 0.5
Second environment: Python 3.10, PyTorch 1.11.0, torchvision 0.12
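For reference, the error can be reproduced without CIFAR10. torchvision's CIFAR10 stores its targets attribute as a plain Python list, and a list only accepts integers or slices as indices, so Python tries to turn the whole mask tensor into a single integer and fails. The labels and mask below are made-up stand-ins, not real CIFAR10 values.

```python
import torch

# Stand-in for CIFAR10's targets attribute: a plain Python list of ints
labels = [6, 9, 9, 4, 1]
mask = torch.tensor([True, False, False, True, False])

# list.__getitem__ calls mask.__index__(), which only works for a
# single-element integer tensor, so this raises TypeError
try:
    labels[mask]
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

# The same boolean mask works once the labels are a tensor
print(torch.tensor(labels)[mask])  # tensor([6, 4])
```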

Is there a way to make this work? I think one possibility is to use np.where(), but I would like to know whether a type conversion would make this work instead. Thank you in advance!
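One type-conversion route, sketched here under the assumption that data is a NumPy array of shape (N, H, W, C) and targets is a list of N ints (which matches torchvision's CIFAR10), is to turn the list into a NumPy array so that comparisons and boolean masks work element-wise. The helper name filter_classes and the tiny synthetic arrays are hypothetical, chosen so the example runs without downloading anything; np.isin is swapped in here for the explicit loop.

```python
import numpy as np


def filter_classes(data, targets, keep):
    """Keep only the samples whose label is in `keep` (hypothetical helper).

    data:    NumPy array of images, shape (N, H, W, C), like CIFAR10's .data
    targets: Python list of N integer labels, like CIFAR10's .targets
    keep:    iterable of class ids to keep, e.g. [0, 1]
    """
    labels = np.asarray(targets)            # list -> ndarray, so masks work element-wise
    mask = np.isin(labels, list(keep))      # boolean mask of shape (N,)
    return data[mask], labels[mask].tolist()  # back to a list, as CIFAR10 uses


# Tiny synthetic stand-in for CIFAR10 (no download needed)
data = np.arange(6 * 2 * 2 * 3, dtype=np.uint8).reshape(6, 2, 2, 3)
targets = [0, 3, 1, 1, 7, 0]
new_data, new_targets = filter_classes(data, targets, keep=[0, 1])
print(new_data.shape, new_targets)  # (4, 2, 2, 3) [0, 1, 1, 0]
```

Inside __init__ this would become something like self._data.data, self._data.targets = filter_classes(self._data.data, self._data.targets, targets).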

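If you index with a bool mask, the True/False values may end up being treated as the integers 1 and 0, so you'd just get many copies of the first and second items in your targets. I suggest using an integer (int dtype) index instead.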
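A minimal sketch of the integer-index approach, assuming a CIFAR10-style list of labels and a NumPy image array (both synthetic here): torch.nonzero converts the boolean mask into integer positions, which can then index the NumPy array directly and the list one element at a time. np.where(mask)[0] on a NumPy mask would give the same positions.

```python
import numpy as np
import torch

targets = [0, 3, 1, 1, 7, 0]                          # list of labels, as in CIFAR10
data = np.arange(6 * 2 * 2 * 3, dtype=np.uint8).reshape(6, 2, 2, 3)  # placeholder images

t = torch.tensor(targets)
mask = (t == 0) | (t == 1)                            # boolean mask, shape (6,)

# nonzero returns positions with shape (K, 1); flatten to (K,)
idx = torch.nonzero(mask).flatten()

# Integer positions work for the NumPy array and, one at a time, for the list
new_data = data[idx.numpy()]
new_targets = [targets[i] for i in idx.tolist()]
print(idx.tolist(), new_targets, new_data.shape)  # [0, 2, 3, 5] [0, 1, 1, 0] (4, 2, 2, 3)
```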

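Also note that the loop above doesn't build the mask you expect: self._data.targets is a Python list, so self._data.targets == target compares the whole list with an integer and yields a single False rather than an element-wise mask, and idx never changes.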
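A short sketch of the difference, with made-up labels: comparing a list with an int gives one Python bool, while the same comparison on a tensor is element-wise, so the OR-accumulating loop then works as intended.

```python
import torch

targets = [0, 3, 1, 1, 7, 0]   # CIFAR10-style list of labels (made up)
classes = [0, 1]               # classes to keep

# Comparing a whole list with an int is a single Python bool, not a mask
print(targets == 0)            # False

# With a tensor, == is element-wise, so the loop can OR the per-class masks together
t = torch.tensor(targets)
mask = torch.zeros_like(t, dtype=torch.bool)
for c in classes:
    mask |= (t == c)
print(mask.tolist())           # [True, False, True, True, False, True]
```

Recent PyTorch releases (not 1.4) also provide torch.isin, which builds the same mask in one call.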