TensorDataset() Dimension Mismatch

Hi Everyone,

I’m using PyTorch’s MNIST dataset and trying to understand how TensorDataset() works. My goal is to unpack the MNIST dataset into data and label tensors, run some operations on them, and then put them back together with TensorDataset().

Before I start modifying the data tensor, I wanted to check that I can rebuild something that looks like the original datasets.MNIST with TensorDataset(), but the dimensions don’t match: I seem to lose the channel dimension. How can I get the dimensions to agree with the original dataset after unpacking and repacking it with TensorDataset()?

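import torch
from torchvision import datasets, transforms

# ToTensor() turns each PIL image into a [1, 28, 28] float tensor; Normalize() needs a tensor
test_mnist = datasets.MNIST('data', train=False, download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))]))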

# Unpacked Data
rt_mnist = torch.utils.data.TensorDataset(test_mnist.data, test_mnist.targets)  
test_loader1 = torch.utils.data.DataLoader(rt_mnist, batch_size=64, shuffle=True)

# Original Data
test_loader2 = torch.utils.data.DataLoader(test_mnist,
        batch_size=64, shuffle=True)

# next(iterator) replaces the .next() method, which newer PyTorch releases removed
dataiter = iter(test_loader1)
images1, labels1 = next(dataiter)
dataiter = iter(test_loader2)
images2, labels2 = next(dataiter)

print(images1.shape, images2.shape)

torch.Size([64, 28, 28]) torch.Size([64, 1, 28, 28])

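TensorDataset doesn't actually squeeze anything: it indexes each tensor along its first dimension only, so every sample keeps the remaining dimensions. The difference comes from the source tensor. test_mnist.data has shape [10000, 28, 28] with no channel dimension; the channel is only added when ToTensor() converts each image inside datasets.MNIST.__getitem__.

A small dummy example shows the indexing: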

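import torch

tensor_data = torch.randn(10).unsqueeze(0).unsqueeze(0)  # shape [1, 1, 10]
custom_dataset = torch.utils.data.TensorDataset(tensor_data)
print(custom_dataset[0][0].shape)  # indexing removes only dim 0
>> torch.Size([1, 10])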
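The same behaviour can be checked on MNIST-shaped data without downloading anything. This is a minimal sketch that uses random uint8 tensors as a stand-in for test_mnist.data and test_mnist.targets (an assumption: same shapes and dtype as the real split, fake values). It shows that TensorDataset keeps whatever trailing dimensions the input has, and that unsqueeze(1) restores the [N, 1, 28, 28] layout the DataLoader then batches.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Synthetic stand-in for test_mnist.data / test_mnist.targets (random values, real shapes).
images = torch.randint(0, 256, (100, 28, 28), dtype=torch.uint8)
labels = torch.randint(0, 10, (100,))

# TensorDataset indexes every tensor along dim 0 only; nothing is squeezed.
flat_ds = TensorDataset(images, labels)
print(flat_ds[0][0].shape)  # torch.Size([28, 28])

# Insert a channel dimension at position 1: [N, 28, 28] -> [N, 1, 28, 28].
chan_ds = TensorDataset(images.unsqueeze(1), labels)
print(chan_ds[0][0].shape)  # torch.Size([1, 28, 28])

# The DataLoader stacks samples along a new leading batch dimension.
loader = DataLoader(chan_ds, batch_size=64, shuffle=True)
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([64, 1, 28, 28])
```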

Thanks for your reply. I think I got it!
I just needed to reshape the data tensor (test_mnist.data), so updating the following line works:

rt_mnist = torch.utils.data.TensorDataset(test_mnist.data.unsqueeze(1), test_mnist.targets)
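One more difference remains after the reshape: the transform passed to datasets.MNIST is applied per sample in __getitem__, not to test_mnist.data, so the repacked tensor still holds raw uint8 pixels in [0, 255] instead of normalized floats. If the goal is to match the original loader's values as well as its shape, a minimal sketch is to apply the equivalent arithmetic yourself. The helper name to_normalized_float is hypothetical, and it assumes the same mean/std as the Normalize call above.

```python
import torch

# Same constants as transforms.Normalize((0.1307,), (0.3081,)) in the question.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def to_normalized_float(raw: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mimic ToTensor() + Normalize() on a raw uint8 [N, 28, 28] tensor.

    Returns a float32 tensor of shape [N, 1, 28, 28].
    """
    x = raw.unsqueeze(1).float() / 255.0  # ToTensor-style scaling of uint8 to [0, 1]
    return (x - MNIST_MEAN) / MNIST_STD   # per-channel normalization, single channel


# Quick check on synthetic pixels: all-white (255) and all-black (0) images.
white = to_normalized_float(torch.full((3, 28, 28), 255, dtype=torch.uint8))
black = to_normalized_float(torch.zeros((3, 28, 28), dtype=torch.uint8))
```

With the real data this would be used as torch.utils.data.TensorDataset(to_normalized_float(test_mnist.data), test_mnist.targets). Doing the conversion once up front on the whole tensor is cheap for MNIST's size and avoids per-sample PIL conversion, at the cost of holding the float copy in memory.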