Treat MNIST data set as a regular data set

I want to perform some pre-selection on the data of the MNIST data set. However, I notice that the data set does not behave the way I would like it to. I load it like this:

train_data = MNIST('./files/', train=True, download=True,
                   transform=torchvision.transforms.Compose([
                       torchvision.transforms.ToTensor(),
                       torchvision.transforms.Normalize((0.1307,), (0.3081,))]))

For example, trying to select the first five examples with a slice like

train_data[:5]

gives the error

ValueError: only one element tensors can be converted to Python scalars

So, to circumvent this, I create a new data set like

train_data = TensorDataset(torch.unsqueeze(, dim=1), train_data.targets)

which does not seem very elegant, but it makes the data set behave the way I want it to. However, I notice that this way all the transformations I applied to the original data set are lost. What can I do to make the data set behave the way I want without losing the transformations? I know the Subset class exists, but I would rather not use it, since I also want to modify the data later. For that later modification I cannot simply write a transform function, because the modification also depends on the label, which the transform argument of MNIST does not allow.

You would need to index the Dataset directly, one item at a time, e.g. via a loop:

for i in range(5):
    data, target = train_data[i]

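instead of a slice. The transform is applied per item inside __getitem__, which expects a single integer index, so a slice fails when the target is converted to a Python int.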
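
If you also want slicing and a modification that depends on the label, one option is a thin wrapper around the original data set. The following is only a sketch under some assumptions: the class name LabelAwareDataset, the label_transform hook and the negate_odd rule are made up for illustration, and a plain list of (data, label) pairs stands in for the MNIST object so the snippet runs without a download. It deliberately avoids importing torch; in a real project you would probably subclass torch.utils.data.Dataset instead.

```python
class LabelAwareDataset:
    """Wrap a map-style dataset and apply a transform that sees the label.

    `base` is anything that supports len() and integer indexing and
    returns (data, target) pairs, e.g. the `train_data` MNIST object.
    """

    def __init__(self, base, label_transform=None):
        self.base = base
        # Hypothetical hook: a callable (data, target) -> (data, target).
        self.label_transform = label_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        # Expand a slice into per-item lookups, because the wrapped
        # dataset is only asked for one integer index at a time.
        if isinstance(index, slice):
            return [self[i] for i in range(*index.indices(len(self)))]
        # The wrapped dataset applies its own transforms in this call.
        data, target = self.base[index]
        if self.label_transform is not None:
            data, target = self.label_transform(data, target)
        return data, target


# Stand-in for the MNIST object so the sketch runs without a download.
toy_base = [(float(i), i % 2) for i in range(10)]


def negate_odd(data, target):
    # Illustrative label-dependent rule: flip the sign when the label is 1.
    return (-data, target) if target == 1 else (data, target)


wrapped = LabelAwareDataset(toy_base, negate_odd)
first_five = wrapped[:5]  # list of five (data, target) pairs
```

With the real data you would pass train_data in place of toy_base, so the transforms given to MNIST still run on every item before the label-dependent step. Note that a slice here returns a list of pairs rather than a stacked tensor; if you need a batch, stacking the items (or using a DataLoader, which as far as I know only needs __len__ and __getitem__ for map-style data sets) would be the next step.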