I checked the source of the Subset class (torch.utils.data.dataset — PyTorch 1.13 documentation) and saw that passing a list to a Subset object, as in subset.__getitem__([1,2,3]), should work, but it actually doesn't:
…/dataset.py", line 272, in __getitem__
return self.dataset[self.indices[idx]]
TypeError: list indices must be integers or slices, not list
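The TypeError itself can be reproduced without PyTorch. The sketch below assumes that, in the release that raised it, self.indices is a plain Python list (the message suggests this, but it is an assumption); the names indices, idx, message and mapped are only illustrative stand-ins.

```python
# Stand-in for Subset.indices, assumed here to be a plain Python list.
indices = list(range(0, 10, 2))
idx = [0, 1, 3]

try:
    # A Python list cannot be indexed with another list.
    indices[idx]
except TypeError as err:
    message = str(err)
    print(message)

# Mapping the positions one at a time works.
mapped = [indices[i] for i in idx]
print(mapped)
```

Newer releases do this element-wise mapping inside Subset.__getitem__ when a list is passed.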
It seems to work for me:
import torch
from torch.utils.data import Subset, TensorDataset

dataset = TensorDataset(torch.arange(10))
subset = Subset(dataset, indices=list(range(0, 10, 2)))
for data in subset:
    print(data)
# (tensor(0),)
# (tensor(2),)
# (tensor(4),)
# (tensor(6),)
# (tensor(8),)
data = subset.__getitem__([0, 1, 3])
print(data)
# (tensor([0, 2, 6]),)
data = subset[[0, 1, 3]]
print(data)
# (tensor([0, 2, 6]),)
data = subset[3]
print(data)
# (tensor(6),)
This feature was added in June 2021 in this PR, so you might need to update PyTorch if you are using an older release.
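If updating is not an option, a small helper might serve as a workaround. This is only a sketch: take is a hypothetical function, not part of PyTorch, and it relies only on the dataset and indices attributes that Subset exposes. A SimpleNamespace stands in for a real Subset so the example runs without PyTorch installed.

```python
from types import SimpleNamespace


def take(subset, idx):
    """Hypothetical helper: index a Subset-like object with an int or a list of ints."""
    if isinstance(idx, list):
        # Fetch the samples one by one, so this does not depend on
        # list support in Subset.__getitem__.
        return [subset.dataset[subset.indices[i]] for i in idx]
    return subset.dataset[subset.indices[idx]]


# Stand-in for Subset(dataset, indices); a real Subset has the same two attributes.
fake_subset = SimpleNamespace(dataset=list(range(10)), indices=list(range(0, 10, 2)))

batch = take(fake_subset, [0, 1, 3])
single = take(fake_subset, 3)
print(batch)
print(single)
```

Unlike the built-in list indexing in newer releases, which passes the whole mapped list on to the underlying dataset, this helper returns a Python list of individual samples, so you would need to stack or collate them yourself.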
Sorry, I was indeed using an older PyTorch version (1.7.0). Thanks for the reply.