Training with replacement

Hello everyone!
I’m trying to train a deep neural network by sampling my images with replacement. In other words, if I have, for example, a batch size of 10 and a dataset of (let’s say) 1000 images, I would like to create 100 batches in which each sample is drawn at random from the whole dataset.
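For reference, the sampling scheme described above can be sketched with plain index tensors. The sketch assumes the sizes from the example (1000 items, batch size 10). It only illustrates the idea and is not how DataLoader implements sampling internally.

```python
import torch

# Sizes taken from the example above (assumed, not from a real dataset).
dataset_size = 1000
batch_size = 10
num_batches = dataset_size // batch_size  # 100

# Every entry is drawn independently and uniformly from [0, dataset_size),
# so the same index can appear several times (sampling with replacement).
indices = torch.randint(0, dataset_size, (num_batches, batch_size))
print(indices.shape)  # torch.Size([100, 10])
```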

What I’m doing is the following:

import torch
import torchvision

batch_size = 10
random_train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./datasets/', train=True, transform=torchvision.transforms.ToTensor()),
        batch_size=batch_size,
        sampler=torch.utils.data.sampler.RandomSampler(torchvision.datasets.MNIST, replacement=True)
    )
for input_images, labels in iter(random_train_loader):
    # I do stuff
    pass

Sadly, this raises the following error:

TypeError: object of type 'type' has no len()

I’ve probably misunderstood the documentation of RandomSampler, any suggestion?

Many thanks!
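The message comes from RandomSampler needing the length of whatever it receives as its data source. The minimal sketch below reproduces it. It uses TensorDataset as a stand-in for MNIST so nothing has to be downloaded, which is an assumption; any Dataset class passed as a class behaves the same way. In recent PyTorch versions the length is requested as soon as the sampler is constructed.

```python
from torch.utils.data import RandomSampler, TensorDataset

error_message = None
try:
    # The class object itself is passed here, not an instance of it.
    # RandomSampler asks for len(data_source), and a class has no len().
    RandomSampler(TensorDataset, replacement=True)
except TypeError as err:
    error_message = str(err)

print(error_message)  # object of type 'type' has no len()
```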

@Simone256 try the following:

import torch
import torchvision

batch_size = 10
dataset = torchvision.datasets.MNIST('./datasets/', train=True, transform=torchvision.transforms.ToTensor())
random_train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.sampler.RandomSampler(dataset, replacement=True)
    )
for input_images, labels in random_train_loader:
    # I do stuff
    pass

You should pass the dataset instance to RandomSampler, not the class. The sampler calls len() on whatever it receives, and a class has no length.
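A self-contained version of the fix is sketched below. A TensorDataset of 1000 random "images" with labels 0..999 stands in for MNIST; this data is hypothetical and only serves to make the sketch run offline. Because label i equals index i here, the labels show directly which indices the sampler drew.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Hypothetical stand-in for MNIST: 1000 random 1x28x28 images, label i == i.
dataset = TensorDataset(torch.randn(1000, 1, 28, 28), torch.arange(1000))
batch_size = 10

# Pass the dataset *instance*, not the class.
sampler = RandomSampler(dataset, replacement=True)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

print(len(sampler))  # 1000 draws per epoch (defaults to len(dataset))
print(len(loader))   # 100 batches of 10

seen = []
for input_images, labels in loader:
    seen.append(labels)  # labels equal the sampled indices in this setup
all_labels = torch.cat(seen)

# With replacement, some indices usually repeat and others never appear.
print(all_labels.numel(), all_labels.unique().numel())
```

If a different number of draws per epoch is wanted, RandomSampler also accepts a `num_samples` argument together with `replacement=True`.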

Ok, sorry! Now it works. I suppose passing the full

torchvision.datasets.MNIST('./datasets/', train=True, transform=torchvision.transforms.ToTensor())

to RandomSampler would have worked too. Sorry, that was a really stupid mistake.
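That should indeed work, because RandomSampler only uses the length of its data source and yields integer indices, which the DataLoader then uses to index its own dataset. The sketch below illustrates this by giving the sampler a plain `range` of the same length; the TensorDataset is again a hypothetical stand-in for MNIST. With MNIST, a second `MNIST(...)` instance would also load the data a second time, so reusing the same instance is the cheaper choice.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Hypothetical stand-in for MNIST.
dataset = TensorDataset(torch.randn(1000, 1, 28, 28), torch.arange(1000))

# The sampler only needs len(data_source); any object of the same length
# (here a plain range) produces an equivalent stream of random indices.
sampler = RandomSampler(range(len(dataset)), replacement=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

indices = list(sampler)
print(len(indices), min(indices) >= 0, max(indices) < len(dataset))  # 1000 True True
```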

Just to ask one more thing: if I write something like this

import torch

data = torch.randn(10, 1, 28, 28)
target = torch.arange(10)

dataset = torch.utils.data.TensorDataset(data, target)

is dataset “the same thing” as the dataset you define in your example? By “the same thing” I mean “another instance of the same class” (I suppose). So can I treat them equivalently?

Yes. A dataset instance should be an instance of a class derived from torch.utils.data.Dataset, and torchvision.datasets.MNIST and torch.utils.data.TensorDataset both derive from torch.utils.data.Dataset, so either kind of instance can be passed to DataLoader and RandomSampler.
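To illustrate that any Dataset subclass can be used the same way, the sketch below runs a TensorDataset and a small custom subclass through identical sampling code. `SquaresDataset` is a hypothetical class written only for this example, and torchvision is left out so the sketch has no download step.

```python
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, TensorDataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: item i is (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx), torch.tensor(idx ** 2)


tensor_ds = TensorDataset(torch.randn(10, 1, 28, 28), torch.arange(10))
custom_ds = SquaresDataset(10)

results = []
for ds in (tensor_ds, custom_ds):
    # Both are Dataset subclasses, so the same sampler/loader code applies.
    assert isinstance(ds, Dataset)
    loader = DataLoader(ds, batch_size=5, sampler=RandomSampler(ds, replacement=True))
    x, y = next(iter(loader))
    results.append((type(ds).__name__, len(loader), tuple(x.shape)))

print(results)
```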

Please see the docs for more details: