I want to draw a random subset of my dataset through a DataLoader.
I have the following implementation:
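```python
import torch
from torch.utils.data import SubsetRandomSampler
from torchvision import datasets, transforms

# args, totalNumInTrainSet and kwargs are defined elsewhere in the script.
# random_() draws each index independently, so duplicates are possible.
train_indices = torch.LongTensor(args.train_set_size).random_(0, totalNumInTrainSet)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,)),
                   ])),
    batch_size=args.batch_size, sampler=SubsetRandomSampler(train_indices), **kwargs)
```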
This works; however, generating the training indices this way produces duplicates. I want all generated indices to be unique. How would I do this efficiently?
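To see how many duplicates to expect, here is a small sketch. It assumes MNIST's 60,000-image training split and a hypothetical subset size of 10,000 for `args.train_set_size`, and uses `torch.randint`, which, like `random_()`, draws every index independently (with replacement). The expected number of distinct indices among k draws from N is N * (1 - (1 - 1/N) ** k).

```python
import torch

# Hypothetical sizes: MNIST's training split has 60,000 images; 10,000 is
# just an example value standing in for args.train_set_size.
total = 60_000
subset_size = 10_000

# Expected number of distinct values when drawing subset_size indices
# uniformly with replacement from range(total).
expected_unique = total * (1 - (1 - 1 / total) ** subset_size)
print(f"expected distinct indices: {expected_unique:.0f} of {subset_size}")

# Empirical check: randint samples with replacement, just like random_().
torch.manual_seed(0)
indices = torch.randint(0, total, (subset_size,))
num_unique = indices.unique().numel()
print(f"observed distinct indices: {num_unique} of {subset_size}")
```

With these sizes, a bit under 8% of the drawn indices are expected to be repeats, so the duplicates are not a rare edge case.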
How about just passing shuffle=True to the DataLoader?
If you are working with samplers, you could also pass RandomSampler with replacement=False to the DataLoader; it is equivalent to shuffle=True.
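A minimal sketch of the two options, using a hypothetical 10-element `TensorDataset` in place of MNIST. With `shuffle=True` the DataLoader builds a `RandomSampler` internally, so both loaders visit every index exactly once per epoch, in random order.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Hypothetical toy dataset of 10 samples standing in for MNIST.
dataset = TensorDataset(torch.arange(10))

# Option 1: let the DataLoader shuffle.
loader_shuffle = DataLoader(dataset, batch_size=4, shuffle=True)

# Option 2: pass the sampler explicitly; replacement=False yields a
# random permutation of all indices, like shuffle=True.
sampler = RandomSampler(dataset, replacement=False)
loader_sampler = DataLoader(dataset, batch_size=4, sampler=sampler)

# Each batch is a one-element list holding a tensor of sample values.
seen_shuffle = torch.cat([batch for (batch,) in loader_shuffle]).tolist()
seen_sampler = torch.cat([batch for (batch,) in loader_sampler]).tolist()
print(seen_shuffle)
print(seen_sampler)
```

Note that both variants always cover the whole dataset; neither restricts the epoch to a subset.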
If I use RandomSampler with replacement=False, I can't specify the number of elements. I get this message:
With replacement=False, num_samples should not be specified, since a random permute will be performed.
Will RandomSampler with replacement=True and num_samples=n always return a sample with n unique elements?
No. Because replacement=True, indices are drawn with replacement, so the sample will probably contain repeated elements.
With replacement=False, RandomSampler goes through the whole dataset.
If you want to specify the number of samples, you could try SubsetRandomSampler.
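A quick sketch contrasting the two modes on a hypothetical 10-element data source (a plain `range`, which works because `RandomSampler` only needs `len()` of its data source). The exact indices drawn depend on the seed and PyTorch version, so only the structure is meaningful: the first list can repeat indices, the second is always a full permutation.

```python
import torch
from torch.utils.data import RandomSampler

torch.manual_seed(0)
data = range(10)  # hypothetical data source; only its length is used

# replacement=True: each of the num_samples indices is drawn independently,
# so repeated indices are possible (and likely).
with_replacement = list(RandomSampler(data, replacement=True, num_samples=10))

# replacement=False: a random permutation of all len(data) indices.
without_replacement = list(RandomSampler(data, replacement=False))

print(with_replacement, "unique:", len(set(with_replacement)))
print(without_replacement, "unique:", len(set(without_replacement)))
```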
Yep, that is what I have done above: I used SubsetRandomSampler and generated random indices to match the size I wanted. I wanted to know if there is an efficient way to generate unique random indices to pass to SubsetRandomSampler.
You could try this:
```python
import numpy as np
import torch

# replace=False samples without replacement, so every index is unique
train_indices = torch.from_numpy(
    np.random.choice(totalNumInTrainSet, size=(args.train_set_size,), replace=False))
```
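An end-to-end sketch of this suggestion, runnable without downloading MNIST. The sizes and the `TensorDataset` are hypothetical stand-ins for `totalNumInTrainSet`, `args.train_set_size`, and the MNIST dataset. The `.long()` call is an extra safeguard, since `np.random.choice` can return `int32` on some platforms and older PyTorch versions only accept `int64` tensors as indices.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical stand-ins for the values in the question.
total_num_in_train_set = 1000
train_set_size = 100
batch_size = 32

dataset = TensorDataset(torch.arange(total_num_in_train_set))

# Draw train_set_size distinct indices (no replacement); ensure int64.
train_indices = torch.from_numpy(
    np.random.choice(total_num_in_train_set, size=(train_set_size,), replace=False)
).long()

# SubsetRandomSampler reshuffles the given indices every epoch.
loader = DataLoader(dataset, batch_size=batch_size,
                    sampler=SubsetRandomSampler(train_indices))

# One epoch: each batch is a one-element list holding a tensor of values.
seen = torch.cat([batch for (batch,) in loader])
print(len(seen), "samples,", seen.unique().numel(), "unique")
```

Because the subset is fixed once and only its order changes, every epoch trains on the same `train_set_size` samples.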
Sorry, I missed that.
I think you could try torch.randperm() (with a fixed torch.manual_seed() if you need reproducibility), or torch.multinomial() with sampling probabilities and replacement=False.
```python
import torch

torch.manual_seed(1)
# returns a random permutation of the integers 0 to num_samples - 1
random_indices = torch.randperm(num_samples)
```
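A sketch of how either function can produce a fixed-size subset of unique indices without NumPy. The sizes are hypothetical stand-ins for `totalNumInTrainSet` and `args.train_set_size`. Slicing a full `randperm` is simple; `multinomial` with uniform weights and `replacement=False` yields the same kind of uniformly random subset, and non-uniform weights would let you bias the selection.

```python
import torch

# Hypothetical sizes standing in for totalNumInTrainSet / args.train_set_size.
total_num_in_train_set = 60_000
train_set_size = 10_000

torch.manual_seed(1)  # fixed seed -> same subset on every run

# Option A: shuffle all indices, keep the first train_set_size.
train_indices = torch.randperm(total_num_in_train_set)[:train_set_size]

# Option B: multinomial over uniform weights, drawn without replacement.
weights = torch.ones(total_num_in_train_set)
train_indices_mn = torch.multinomial(weights, train_set_size, replacement=False)

print(train_indices.unique().numel(), train_indices_mn.unique().numel())
```

Either tensor can then be passed to `SubsetRandomSampler` exactly as in the question.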
I will use NumPy until this (or an equivalent) is merged: https://github.com/pytorch/pytorch/pull/18624