Trouble with ConcatDataset()

Hi Everyone,

I’m trying to understand the ConcatDataset class. I’m running this example code from here:

The only change I’m making here is this:

train_mnist = datasets.MNIST('data', download=True, train=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307,), (0.3081,)),
                             ]))
new_mnist = torch.utils.data.TensorDataset(train_mnist.data.unsqueeze(1), train_mnist.targets)
trainset = torch.utils.data.ConcatDataset([new_mnist]) 
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)

Therefore, instead of passing train_mnist to the DataLoader, I’m creating a TensorDataset and then using ConcatDataset to see if I get the same train/test behavior. Surprisingly, the training/testing behavior is very different - instead of getting high-90% test accuracy, test accuracy drops to 10%. Something is definitely wrong! Does anyone know what’s going on?

The transformations will be lost if you pass train_mnist.data directly to TensorDataset, since they are applied inside the dataset's __getitem__ method to each sample. You could normalize the data in place before passing it to the TensorDataset.
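A minimal sketch of that fix, reproducing what the Compose pipeline would have done per sample (ToTensor scales uint8 [0, 255] to float [0, 1], then Normalize applies the mean/std from the question). A random tensor stands in here for train_mnist.data, which is raw uint8 of shape [N, 28, 28]; in the actual code you would use train_mnist.data and train_mnist.targets directly:

```python
import torch

# Stand-in for train_mnist.data / train_mnist.targets from the snippet above
# (MNIST .data is raw uint8 in [0, 255], shape [N, 28, 28]).
raw_data = torch.randint(0, 256, (1000, 28, 28), dtype=torch.uint8)
targets = torch.randint(0, 10, (1000,))

# Apply the transforms up front, since TensorDataset will not call them:
# ToTensor-equivalent scaling, then the Normalize((0.1307,), (0.3081,)) step.
data = raw_data.unsqueeze(1).float() / 255.0
data = (data - 0.1307) / 0.3081

new_mnist = torch.utils.data.TensorDataset(data, targets)
trainset = torch.utils.data.ConcatDataset([new_mnist])
train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
```

Without this, the model trains on unnormalized uint8-valued inputs in [0, 255] rather than the normalized floats it saw with the original DataLoader, which explains the collapse to 10% (chance-level) accuracy.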