Hi Everyone,
I’m trying to understand `ConcatDataset` (`torch.utils.data.ConcatDataset`, which is a class rather than a function). I’m running this example code from here:
The only change I’m making here is this:
```python
import torch
from torchvision import datasets, transforms

train_mnist = datasets.MNIST('data', download=True, train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]))
# Wrap the .data/.targets tensors; unsqueeze(1) adds a channel dim -> (N, 1, 28, 28)
new_mnist = torch.utils.data.TensorDataset(train_mnist.data.unsqueeze(1), train_mnist.targets)
trainset = torch.utils.data.ConcatDataset([new_mnist])
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)  # args from the example script
```
So instead of passing `train_mnist` directly to the `DataLoader`, I’m building a `TensorDataset` from its `data` and `targets` tensors and wrapping that in `ConcatDataset`, to check whether I get the same train/test behavior. Surprisingly, the behavior is very different: instead of test accuracy in the high 90s, it drops to 10%, which is roughly chance level for 10 classes. Something is definitely wrong! Does anyone know what’s going on?
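One way to narrow this down is to compare what each path actually hands to the model. The sketch below is a minimal illustration, assuming a small synthetic `uint8` tensor (the hypothetical names `raw` and `targets`) stands in for `train_mnist.data` and `train_mnist.targets`, so it runs without downloading MNIST. It rests on two documented behaviors: `TensorDataset` retrieves samples by indexing the stored tensors along the first dimension, with no `transform` applied, and `ConcatDataset` over a single dataset simply forwards indexing to it. The "transformed" path approximates `ToTensor()` followed by `Normalize((0.1307,), (0.3081,))` by hand.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for train_mnist.data / train_mnist.targets:
# four random 28x28 uint8 "images" with pixel values in [0, 255].
raw = torch.randint(0, 256, (4, 28, 28), dtype=torch.uint8)
targets = torch.tensor([0, 1, 2, 3])

# Path 1: the TensorDataset route from the post. TensorDataset only indexes
# the stored tensors, so the samples come out as raw uint8 pixels.
tensor_ds = TensorDataset(raw.unsqueeze(1), targets)
raw_ds = ConcatDataset([tensor_ds])  # single dataset: indexing is forwarded as-is
x_raw, y_raw = raw_ds[0]

# Path 2: a hand-written approximation of the MNIST transform pipeline.
# ToTensor() scales uint8 pixels to floats in [0, 1]; Normalize then
# subtracts the mean and divides by the std used in the post.
mean, std = 0.1307, 0.3081
processed = (raw.unsqueeze(1).float() / 255.0 - mean) / std
proc_ds = ConcatDataset([TensorDataset(processed, targets)])
x_proc, y_proc = proc_ds[0]

print("TensorDataset path:", x_raw.dtype, x_raw.min().item(), x_raw.max().item())
print("Transformed path:  ", x_proc.dtype,
      round(x_proc.min().item(), 3), round(x_proc.max().item(), 3))
```

If the two paths report very different value ranges (raw pixels up to 255 versus normalized floats between roughly -0.42 and 2.82), then a model trained through the `TensorDataset` path would see inputs on a different scale from a test loader that still uses the transformed MNIST dataset. That is one plausible cause worth checking, not a confirmed diagnosis.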