Modifying a dataset after creating the dataset instance


I would like to modify dataset values during training. The code below shows what I want: changing a sample's label inside the optimization loop (my actual dataset is a research-specific one rather than MNIST). Of course, this code won't work, because changing labels in the batch does not change the values stored in the dataset itself.

If anyone in the community has tried something similar, I would really appreciate hearing how you did it.

import torch
from torchvision import datasets
from torchvision import transforms

mnist_dataset = datasets.MNIST("/tmp", download=True, transform=transforms.ToTensor()) # whatever
mnist_small_dataset, _ = torch.utils.data.random_split(mnist_dataset, [4, len(mnist_dataset)-4])
loader = torch.utils.data.DataLoader(mnist_small_dataset, batch_size=2, shuffle=False)

for samples in loader:
    _, labels = samples
    # compute loss and backward() and optimize...
    labels[0] = 100000 # whatever
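To see why, note that the DataLoader's default collation stacks the per-sample labels into a new tensor for every batch, so assigning into labels only changes that batch's copy. A minimal check, using a small TensorDataset with dummy data as a stand-in for the real dataset (an assumption, to avoid downloading MNIST):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset (dummy data): 4 "images" with labels 0..3.
images = torch.zeros(4, 1, 2, 2)
targets = torch.arange(4)
dataset = TensorDataset(images, targets)
loader = DataLoader(dataset, batch_size=2, shuffle=False)

for _, labels in loader:
    # labels is a fresh tensor built by the default collate function,
    # so this write never reaches the stored targets.
    labels[0] = 100000

# The stored labels are unchanged after the loop.
print(targets.tolist())  # [0, 1, 2, 3]
```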

A naive workaround that actually does not work

Another workaround that occurred to me is to define a dataset class that wraps the original dataset and also returns each sample's index, as below:

class WorkaroundDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self._dataset = dataset

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        return (*self._dataset[idx], idx)

Then I tried to modify the dataset value through indexing, but that is not possible because each sample is returned as a tuple, which does not support item assignment. I could modify the PyTorch source code directly so that it returns a list, but that seems hacky and I would rather not do it.

mnist_dataset = WorkaroundDataset(mnist_dataset)
mnist_small_dataset, _ = torch.utils.data.random_split(mnist_dataset, [4, len(mnist_dataset)-4])
loader_workaround = torch.utils.data.DataLoader(mnist_small_dataset, batch_size=2, shuffle=False)

for samples in loader_workaround:
    _, labels, idxes = samples
    # compute loss and backward() and optimize...
    for idx in idxes:
        mnist_dataset._dataset[idx][1] = 100000 # whatever
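The assignment in that loop raises a TypeError because the sample is a tuple. Returning a list would not help either: MNIST builds a new (image, label) pair on every __getitem__ call, so a change to that object is discarded. One way to make the wrapper idea work, sketched here as an assumption rather than an established recipe, is to keep label overrides in a mutable dict inside the wrapper and apply them in __getitem__. RelabelableDataset and set_label are hypothetical names, and a TensorDataset with dummy data stands in for the real dataset. Because random_split's Subset passes the original index through to the wrapper, the returned idx can be used directly as the key.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split


class RelabelableDataset(Dataset):
    """Wraps a dataset of (x, y) pairs and lets labels be overridden by index."""

    def __init__(self, dataset):
        self._dataset = dataset
        self._label_overrides = {}  # original index -> replacement label

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        x, y = self._dataset[idx]
        # Prefer an override if one was set; as_tensor keeps every label a
        # tensor so the default collate function can stack a mixed batch.
        y = torch.as_tensor(self._label_overrides.get(int(idx), y))
        return x, y, idx

    def set_label(self, idx, label):
        # idx arrives as a 0-d tensor when taken from a DataLoader batch.
        self._label_overrides[int(idx)] = label


# Stand-in for the real dataset: 8 dummy images with labels 0..7.
base = TensorDataset(torch.zeros(8, 1, 2, 2), torch.arange(8))
dataset = RelabelableDataset(base)
small, _ = random_split(dataset, [4, len(dataset) - 4])
loader = DataLoader(small, batch_size=2, shuffle=False)

for _, labels, idxes in loader:
    # compute loss and backward() and optimize...
    for idx in idxes:
        dataset.set_label(idx, 100000)

# A second pass over the loader now sees the modified labels.
second_pass = [labels.tolist() for _, labels, _ in loader]
print(second_pass)  # [[100000, 100000], [100000, 100000]]
```

This relies on the DataLoader reading the same dataset object that the loop updates, which holds with the default num_workers=0. With persistent_workers=True, worker processes keep the copy of the dataset they started with, so updates made in the main process would not reach them.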

Hi @HiroIshida

Your workaround is a reasonable approach.

Why don’t you try:

  1. make WorkaroundDataset inherit from datasets.MNIST,
  2. call the parent class's constructor properly in __init__,
  3. override __getitem__ so that it fetches the item from the parent's __getitem__ via super(),
  4. and finally modify the label before __getitem__ returns it.
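A minimal sketch of those four steps. To keep it runnable without downloading MNIST, TensorDataset stands in for datasets.MNIST as the parent class (an assumption). With MNIST, __init__ would forward root, train, transform, download and so on to super().__init__(), and super().__getitem__(idx) likewise returns an (image, label) pair. RelabeledDataset and label_overrides are hypothetical names.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class RelabeledDataset(TensorDataset):
    """Subclass the dataset and adjust labels in __getitem__.

    TensorDataset stands in for datasets.MNIST; the structure is the same.
    """

    def __init__(self, *tensors):
        super().__init__(*tensors)  # step 2: run the parent constructor
        self.label_overrides = {}   # hypothetical: index -> replacement label

    def __getitem__(self, idx):
        x, y = super().__getitem__(idx)  # step 3: get the item from the parent
        # Step 4: change the label before returning it.
        y = torch.as_tensor(self.label_overrides.get(int(idx), y))
        return x, y, idx


dataset = RelabeledDataset(torch.zeros(4, 1, 2, 2), torch.arange(4))
dataset.label_overrides[0] = 100000  # e.g. set from inside the training loop
loader = DataLoader(dataset, batch_size=2, shuffle=False)

batches = [labels.tolist() for _, labels, _ in loader]
print(batches)  # [[100000, 1], [2, 3]]
```

For torchvision's MNIST specifically, the labels live in the targets tensor that __getitem__ reads, so assigning to mnist_dataset.targets[idx] should also persist across epochs. For a custom dataset, whether a similar shortcut works depends on where it stores its labels.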