Hi ptrblck,
Thank you for your response.
Yeah, I’ve noticed that, not sure what causes the shape transformation, I looked at the source code for ToTensor() suspecting that it may be the culprit.
filename = 'RAW_MNIST-train.out'
RAW_MNIST = np.loadtxt(filename, delimiter=',')
# input_data = RAW_MNIST
input_data = RAW_MNIST
X = np.array(input_data[:,1:], np.uint8, copy=False) # need np.uint8 for ToTensor()
y = input_data[:,0].astype(np.int)
y = np.array([np.array(y[i]) for i in range(y.shape[0])])
train_targets = torch.from_numpy(y)
# X_train = X.reshape((28, 28, X.shape[0]))
X_train = X.reshape((X.shape[0], 28, 28))
train_data = data_transform(X_train).float()
print(train_data.shape) # torch.Size([28, 60000, 28])
# Create TensorDataset
train_set = TensorDataset(train_data,train_targets) # returns an error related to mismatched dims[0] of data and targets
# Create DataLoader
dataloader_args = dict(shuffle=True, batch_size=64,num_workers=0, pin_memory=False)
train_loader = dataloader.DataLoader(train_set, **dataloader_args)
The training loss profile in my post above was based on the following:
# notice the need to reshape X
X_train = X.reshape((28, 28, X.shape[0]))
# X_train = X.reshape((X.shape[0], 28, 28))
train_data = data_transform(X_train).float()
print(train_data.shape, train_targets.shape)
# Create TensorDataset
train_set = TensorDataset(train_data,train_targets) # create your datset
# Create DataLoader
dataloader_args = dict(shuffle=True, batch_size=64,num_workers=0, pin_memory=False)
train_loader = dataloader.DataLoader(train_set, **dataloader_args)
Additionally, for the first block of code above;
train_data[0]
# returns
tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])
VS for the second block;
train_data[0]
#returns
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.0000,
0.7020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9098, 0.4745, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0784,
0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5922, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.9961, 1.0000, 0.9059, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.9490,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7725, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.4471, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.9804, 0.0000, 0.8824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0863,
0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.7490, 0.9882, 0.0000,
0.0000],
[0.0000, 0.0000, 0.7961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9961, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922,
0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.9922, 0.9882, 0.9922, 0.0000, 0.0000, 0.0000,
0.0000, 0.9882, 0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000,
0.5647, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0627, 0.0000, 0.9961, 0.0000, 0.0000, 0.0000,
0.0000, 0.9255, 0.0000, 0.2941, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.5569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6039, 0.9098, 0.0000,
0.0000],
[0.0000, 0.6745, 0.0000, 0.2353, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9882, 0.0000, 0.3412,
0.0000],
[0.0000, 0.0000, 0.0863, 0.9882, 0.9882, 0.7412, 0.0000, 0.0000, 0.0000,
0.3137, 0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0000,
1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9882,
0.0000],
[0.0000, 0.0000, 0.0000, 0.9098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.8588, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0353, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.0471,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.1922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.8314, 0.0000,
0.0000],
[0.0000, 0.0000, 0.4510, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.1451, 0.0000, 0.0000, 0.0000, 0.0471, 0.9961,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000, 0.0000, 0.0000, 0.6431,
0.0000, 0.0000, 0.0000, 0.9137, 0.0000, 0.0000, 0.0000, 0.0706, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000,
0.9882, 0.0000, 0.4314, 0.0000, 0.0000, 0.0000, 0.0000, 0.7961, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.4549, 0.0000, 0.9922, 0.0000, 0.9882,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5647, 0.0000, 0.0000, 0.0000,
0.6627, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3882, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8902, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.3020, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.8627, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.1373, 0.0706, 0.0000, 0.0000, 0.0000, 0.0000, 0.0824, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.9922, 0.2667, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.2196, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000, 0.9922,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.4000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 1.0000, 0.0000, 0.9255, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922,
0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.2941, 0.0000, 0.9922, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8824,
0.4667, 0.8235, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.9922, 0.1294, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.1020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.9882, 0.6353, 0.9882, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.9961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.9333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6039, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.2588, 0.9922, 0.4863, 0.0000, 0.0000, 0.0000,
0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7529, 0.0000,
0.9882, 0.4235, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.9922, 0.9647, 0.0000, 0.0000, 0.0000, 0.0000,
0.9922, 0.0000, 0.9020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7294, 0.6039, 0.7216,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.7373, 0.0000, 0.0000, 0.0000, 0.0000,
0.1529, 0.0392, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.1647, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1412, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.6353, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.7490, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5882, 0.9882,
0.0000, 0.9529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.7647, 0.3216, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.2863, 0.0000, 0.0000, 0.0000, 0.0000, 0.1137,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9922, 0.0000, 0.5333, 0.0000,
0.0000]])
Noticing the difference between the two print statements (intuiting it had something to do with the difference in training loss profile), and after some debugging time, I ended up creating my own custom Dataset
class. I still don’t know why the first code block changes the way the tensor is printed (which is similar to the way a tensor from the torchvision
dataset is printed), or why it changes the shape (which triggers the errors in the TensorDatset()
line). I resorted to creating a custom Dataset
class to mimic the way torchvision
dataset is created (looking at the source code for that).
class CustomMNIST(Dataset):
"""CustomMNIST"""
def __init__(self, input_arr, transform=None):
"""
Args:
input_array (np_array): data_array
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.input_arr = input_arr
self.transform = transform
X = np.array(self.input_arr[:,1:], np.uint8, copy=False)
self.images = X.reshape((X.shape[0], 28, 28))
y = self.input_arr[:,0].astype(np.int)
self.target = torch.from_numpy(np.array([np.array(y[i]) for i in range(y.shape[0])]))
def __len__(self):
return len(self.input_arr)
def __getitem__(self, index):
image = self.images[index]
if self.transform:
image = self.transform(image).float()
label = self.target[index]
return (image, label)
My guess is that I’m missing something very fundamental here. Perhaps when the transformation is done in the first code block above, broadcasting is done differently? Or maybe something more low-level in terms of memory and how tensors are stored?