Trouble with ConcatDataset()

Hi Everyone,

I’m trying to understand the ConcatDataset() function. I’m running this example code from here:

The only change I’m making here is this:

train_mnist = datasets.MNIST('data', download=True, train=True, transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))]))
new_mnist = torch.utils.data.TensorDataset(train_mnist.data, train_mnist.targets)
trainset = torch.utils.data.ConcatDataset([new_mnist])
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)

So instead of passing train_mnist to the DataLoader, I'm creating a TensorDataset and then using ConcatDataset to see if I get the same train/test behavior. Surprisingly, the behavior is very different: instead of reaching test accuracy in the high 90s, test accuracy drops to about 10%. Something is definitely wrong! Does anyone know what's going on?

The transformation will be lost if you pass train_mnist.data directly to TensorDataset, because the transform is applied in the __getitem__ method to each sample as it is loaded, not to the stored .data tensor. You could try to normalize the data in place before passing it to the TensorDataset.
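
As a minimal sketch of that suggestion, the snippet below applies the equivalent of ToTensor() and Normalize() by hand before building the TensorDataset. It assumes the raw .data tensor is uint8 with shape [N, 28, 28], which is how torchvision's MNIST stores it. The helper name normalize_raw_mnist is hypothetical, and a random tensor stands in for train_mnist.data and train_mnist.targets so the example runs without downloading anything.

```python
import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader

# Same statistics as in transforms.Normalize((0.1307,), (0.3081,))
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def normalize_raw_mnist(data: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mimic ToTensor() + Normalize() on raw uint8 images."""
    x = data.float().div(255.0)   # ToTensor() scales uint8 pixels to [0, 1]
    x = x.unsqueeze(1)            # add the channel dimension: [N, 1, 28, 28]
    return (x - MNIST_MEAN) / MNIST_STD


# Stand-ins for train_mnist.data and train_mnist.targets (assumption:
# uint8 images of shape [N, 28, 28] and integer class labels).
raw_data = torch.randint(0, 256, (64, 28, 28), dtype=torch.uint8)
raw_targets = torch.randint(0, 10, (64,))

new_mnist = TensorDataset(normalize_raw_mnist(raw_data), raw_targets)
trainset = ConcatDataset([new_mnist])
train_loader = DataLoader(trainset, batch_size=16, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape, images.dtype, labels.dtype)
```

With the real dataset you would pass train_mnist.data and train_mnist.targets instead of the random tensors. The batches should then have the same shape and value range as those produced by the original train_mnist with its transform.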