I’m trying to understand `torch.utils.data.ConcatDataset`. I’m running this example code from here:
The only change I’m making here is this:
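```python
import torch
from torchvision import datasets, transforms

train_mnist = datasets.MNIST('data', download=True, train=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307,), (0.3081,)),
                             ]))
# Build a TensorDataset from the images and labels stored on train_mnist
new_mnist = torch.utils.data.TensorDataset(train_mnist.data.unsqueeze(1), train_mnist.targets)
# Wrap that single dataset in ConcatDataset
trainset = torch.utils.data.ConcatDataset([new_mnist])
# args.batch_size comes from the example script's argument parser
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
```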
So instead of passing `train_mnist` to the `DataLoader` directly, I’m creating a `TensorDataset` and then wrapping it in `ConcatDataset` to see whether I get the same train/test behavior. Surprisingly, the behavior is very different: instead of test accuracy in the high 90s, it drops to 10%, which is chance level for 10 classes. Something is definitely wrong! Does anyone know what’s going on?
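Below is a self-contained sketch of one difference worth ruling out. It uses synthetic uint8 images in place of the real MNIST download (an assumption made only so the snippet runs offline), and `mnist_like_transform` is a hypothetical helper, not a torchvision API. As far as I know, `datasets.MNIST` applies its `transform` inside `__getitem__`, while `train_mnist.data` holds the raw uint8 pixels (0 to 255), and `TensorDataset` returns exactly the tensors it was given. If so, the `ToTensor()`/`Normalize` step never runs in the modified version. `ConcatDataset` itself only forwards each index to the wrapped dataset.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Assumption: synthetic stand-in for MNIST's `.data` / `.targets`,
# 4 deterministic 28x28 uint8 images with pixel values 0..255.
raw = (torch.arange(4 * 28 * 28) % 256).to(torch.uint8).reshape(4, 28, 28)
targets = torch.tensor([0, 1, 2, 3])

# Hypothetical helper: per-pixel math of ToTensor() followed by
# Normalize((0.1307,), (0.3081,)) for a uint8 grayscale image.
def mnist_like_transform(img_uint8):
    return (img_uint8.float() / 255.0 - 0.1307) / 0.3081

# Same construction as in the question: no transform is attached anywhere.
new_ds = TensorDataset(raw.unsqueeze(1), targets)
trainset = ConcatDataset([new_ds])

x_concat, _ = trainset[0]
x_plain, _ = new_ds[0]
print(torch.equal(x_concat, x_plain))  # True: ConcatDataset just forwards the index
print(x_concat.dtype)                  # torch.uint8: raw pixels, not normalized floats

# What the ToTensor() + Normalize pipeline would yield for the same image.
x_expected = mnist_like_transform(raw[0]).unsqueeze(0)
print(x_expected.dtype)                # torch.float32

# One way to keep the normalization with a TensorDataset:
# apply the same math to the whole image tensor up front.
fixed_ds = ConcatDataset([TensorDataset(mnist_like_transform(raw).unsqueeze(1), targets)])
x_fixed, _ = fixed_ds[0]
print(torch.allclose(x_fixed, x_expected))  # True
```

Comparing `trainset[0][0]` with `train_mnist[0][0]` on the real data (dtype, min and max) should show quickly whether this is what changes the results.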