Different training loss profile when using MNIST dataset from txt file vs torchvision dataset

I’m new to PyTorch and I’m facing a strange issue when using my own custom MNIST dataset.

To debug the issue, I’m first comparing MNIST from torchvision datasets against MNIST loaded from a txt file (the txt file was created from a numpy array).

I apply the same transformations as used in torchvision datasets, which seem to be ToTensor() plus appropriate casting to the types that the model expects.

Basic statistics indicate that the two datasets are identical in terms of min, max, mean, and variance, as well as shape:

torchvision dataset:

[image: statistics and shape of the torchvision dataset]

txt dataset:

[image: statistics and shape of the txt dataset]

When I train the model on the torchvision dataset, the training loss goes down rapidly, whereas when training on the txt dataset it gets stuck.

Any idea what could be happening here? Thank you so much for any pointers you can provide.

While the stats look good, the shape looks a bit strange.
I would assume that the number of samples would be stored in dim0 followed by the height and width.
However, somehow the shape is permuted.

Could you get a single sample from both datasets and visualize it to check whether the pixels might be interleaved?
Something like this should work:

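import matplotlib.pyplot as plt

# `dataset` is whichever dataset you want to inspect
data, _ = dataset[0]
print(data.shape)  # [28, 28], or [1, 28, 28] if the transform adds a channel dim
plt.imshow(data.squeeze().numpy(), cmap='gray')  # squeeze() drops a leading channel dim
plt.show()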
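
To illustrate why good-looking stats don't rule this out: min, max, mean, and variance don't depend on where each pixel sits, so they stay the same even if the pixels end up in the wrong places. Here is a minimal NumPy sketch with random stand-in data (not the real MNIST; `images` and `scrambled` are made-up names for the illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a small batch of uint8 images, shape (N, 28, 28)
images = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)

# Shuffle all pixels across the whole array: same values, different positions
scrambled = rng.permutation(images.reshape(-1)).reshape(images.shape)

same_stats = (
    images.min() == scrambled.min()
    and images.max() == scrambled.max()
    and np.isclose(images.mean(), scrambled.mean())
    and np.isclose(images.var(), scrambled.var())
)
same_pixels = np.array_equal(images, scrambled)

print(same_stats)   # True
print(same_pixels)  # False
```

So identical summary statistics only tell you that both datasets contain the same values, not that each sample contains the right ones. Plotting a sample is the quicker check.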

Hi ptrblck,

Thank you for your response.

Yes, I noticed that too. I’m not sure what causes the shape transformation; I looked at the source code for ToTensor(), suspecting that it may be the culprit. This is how I load the txt file:


import numpy as np
import torch
from torch.utils.data import TensorDataset, dataloader
from torchvision import transforms

# Assumption: data_transform is just ToTensor(), as described in the first post
data_transform = transforms.ToTensor()

filename = 'RAW_MNIST-train.out'
RAW_MNIST = np.loadtxt(filename, delimiter=',')
input_data = RAW_MNIST

# Column 0 holds the label, the remaining columns hold the pixel values
X = input_data[:, 1:].astype(np.uint8)  # need np.uint8 for ToTensor()
y = input_data[:, 0].astype(np.int64)   # np.int was removed in NumPy 1.24

train_targets = torch.from_numpy(y)

# X_train = X.reshape((28, 28, X.shape[0]))
X_train = X.reshape((X.shape[0], 28, 28))
train_data = data_transform(X_train).float()

print(train_data.shape) # torch.Size([28, 60000, 28])

# Create TensorDataset
train_set = TensorDataset(train_data, train_targets)  # raises an error: dim 0 of data and targets differ

# Create DataLoader
dataloader_args = dict(shuffle=True, batch_size=64, num_workers=0, pin_memory=False)
train_loader = dataloader.DataLoader(train_set, **dataloader_args)

The training loss profile in my post above was based on the following:

# Note: X has to be reshaped like this for TensorDataset to accept it
X_train = X.reshape((28, 28, X.shape[0]))
# X_train = X.reshape((X.shape[0], 28, 28))
train_data = data_transform(X_train).float()

print(train_data.shape, train_targets.shape)

# Create TensorDataset
train_set = TensorDataset(train_data, train_targets)  # create the dataset

# Create DataLoader
dataloader_args = dict(shuffle=True, batch_size=64,num_workers=0, pin_memory=False)
train_loader = dataloader.DataLoader(train_set, **dataloader_args)

Additionally, for the first block of code above:

train_data[0]

# returns
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

Versus, for the second block:

train_data[0]

# returns
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.0000,
         0.7020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9098, 0.4745, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0784,
         0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5922, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.9961, 1.0000, 0.9059, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.9490,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7725, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.4471, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.9804, 0.0000, 0.8824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0863,
         0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.7490, 0.9882, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.7961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9961, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922,
         0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.9922, 0.9882, 0.9922, 0.0000, 0.0000, 0.0000,
         0.0000, 0.9882, 0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000,
         0.5647, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0627, 0.0000, 0.9961, 0.0000, 0.0000, 0.0000,
         0.0000, 0.9255, 0.0000, 0.2941, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.5569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6039, 0.9098, 0.0000,
         0.0000],
        [0.0000, 0.6745, 0.0000, 0.2353, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9882, 0.0000, 0.3412,
         0.0000],
        [0.0000, 0.0000, 0.0863, 0.9882, 0.9882, 0.7412, 0.0000, 0.0000, 0.0000,
         0.3137, 0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0000,
         1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9882,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.9098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.8588, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0353, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.0471,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.1922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.8314, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.4510, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.1451, 0.0000, 0.0000, 0.0000, 0.0471, 0.9961,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.6431,
         0.0000, 0.0000, 0.0000, 0.9137, 0.0000, 0.0000, 0.0000, 0.0706, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000,
         0.9882, 0.0000, 0.4314, 0.0000, 0.0000, 0.0000, 0.0000, 0.7961, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.4549, 0.0000, 0.9922, 0.0000, 0.9882,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5647, 0.0000, 0.0000, 0.0000,
         0.6627, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3882, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8902, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.3020, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.8627, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.1373, 0.0706, 0.0000, 0.0000, 0.0000, 0.0000, 0.0824, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.9922, 0.2667, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.2196, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000, 0.9922,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.4000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 1.0000, 0.0000, 0.9255, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922,
         0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.2941, 0.0000, 0.9922, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8824,
         0.4667, 0.8235, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.9922, 0.1294, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.1020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.9882, 0.6353, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.9333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6039, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.2588, 0.9922, 0.4863, 0.0000, 0.0000, 0.0000,
         0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7529, 0.0000,
         0.9882, 0.4235, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.9922, 0.9647, 0.0000, 0.0000, 0.0000, 0.0000,
         0.9922, 0.0000, 0.9020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7294, 0.6039, 0.7216,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.7373, 0.0000, 0.0000, 0.0000, 0.0000,
         0.1529, 0.0392, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.1647, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1412, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.6353, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.7490, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5882, 0.9882,
         0.0000, 0.9529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7647, 0.3216, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.2863, 0.0000, 0.0000, 0.0000, 0.0000, 0.1137,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000, 0.5333, 0.0000,
         0.0000]])

I noticed the difference between the two printouts and suspected that it had something to do with the difference in training loss profile. I still don’t know why the first code block changes the way the tensor is printed (it looks similar to the way a tensor from the torchvision dataset is printed), or why it changes the shape (which triggers the error in the TensorDataset() line). After some debugging, I ended up creating my own custom Dataset class that mimics the way the torchvision dataset is built (based on its source code):

import numpy as np
import torch
from torch.utils.data import Dataset


class CustomMNIST(Dataset):
    """MNIST read from a 2D array whose rows are [label, pixel values]."""

    def __init__(self, input_arr, transform=None):
        """
        Args:
            input_arr (np.ndarray): data array; column 0 is the label,
                the remaining columns are the flattened 28x28 image.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.input_arr = input_arr
        self.transform = transform
        X = self.input_arr[:, 1:].astype(np.uint8)  # need np.uint8 for ToTensor()
        self.images = X.reshape((X.shape[0], 28, 28))
        y = self.input_arr[:, 0].astype(np.int64)  # np.int was removed in NumPy 1.24
        self.target = torch.from_numpy(y)

    def __len__(self):
        return len(self.input_arr)

    def __getitem__(self, index):
        # The transform is applied to one (28, 28) image at a time
        image = self.images[index]
        if self.transform:
            image = self.transform(image).float()
        label = self.target[index]
        return (image, label)

My guess is that I’m missing something very fundamental here. Perhaps when the transformation is done in the first code block above, broadcasting is done differently? Or maybe something more low-level in terms of memory and how tensors are stored?
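
One likely explanation, assuming data_transform is torchvision's ToTensor(): when given a NumPy array, ToTensor() treats it as a single image in height x width x channels layout and moves the last axis to the front. It has no notion of a batch dimension, so an array of shape (60000, 28, 28) is read as one image with H=60000, W=28, C=28 and comes back as [28, 60000, 28]. The sketch below mimics that step in plain NumPy on random stand-in data; `to_tensor_like` is a hypothetical helper written for this illustration, not part of torchvision.

```python
import numpy as np

def to_tensor_like(pic):
    """NumPy stand-in for what ToTensor() does to a uint8 ndarray (illustrative only)."""
    if pic.ndim == 2:
        pic = pic[:, :, None]          # a 2D image gets a channel axis: (H, W) -> (H, W, 1)
    out = pic.transpose((2, 0, 1))     # (H, W, C) -> (C, H, W)
    return out.astype(np.float32) / 255.0

n = 6  # small stand-in for the 60000 samples
rng = np.random.default_rng(0)
flat = rng.integers(0, 256, size=(n, 784), dtype=np.uint8)  # one flattened image per row
images = flat.reshape(n, 28, 28)                             # correct per-sample layout

# Reference: scale the whole array without any axis shuffling
batch = images.astype(np.float32) / 255.0

# First code block: the whole (N, 28, 28) array is read as H=N, W=28, C=28
first = to_tensor_like(images)
print(first.shape)  # (28, 6, 28)

# Second code block: the shape comes out as (N, 28, 28), but reshape only
# reinterprets the flat buffer, so each "image" mixes pixels from several rows
second = to_tensor_like(flat.reshape(28, 28, n))
print(second.shape)  # (6, 28, 28)
print(np.array_equal(second[0], batch[0]))  # False for this random data

# Custom Dataset: each (28, 28) image is transformed on its own
single = to_tensor_like(images[0])
print(single.shape)  # (1, 28, 28)
print(np.array_equal(single[0], batch[0]))  # True
```

If this is what happens, the second code block trains on scrambled images whose global statistics still match, which would explain the stuck loss. It would also explain the two printouts: in the first block train_data[0] has shape [60000, 28] and is printed in abbreviated form, while in the second it is a [28, 28] tensor printed in full. The custom Dataset works because the transform only ever sees one 2D image. For a TensorDataset, something like torch.from_numpy(X.reshape(-1, 28, 28)).float().div(255) should give the expected [N, 28, 28] tensor without going through ToTensor(), with .unsqueeze(1) added if the model expects a channel dimension.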