Hi,
Here is the problem: the object returned by random_split (a Subset wrapping MNIST) does not have transform and target_transform attributes. You need to go through its dataset attribute first, e.g. dataset_valid.dataset.transform.
Here you are using dataset_valid, but you are printing the shape of dataset_train.
Also, it seems that for the predictions you used batch_size=5000, which leads to 5000 predictions, but you are using the labels of the entire dataset, which has 10000 samples.
I think you are not looping over batches and want to feed the entire dataset to the network. In that case, change batch_size=dataset_valid.dataset.data.shape[0] in this line:
Yes, sorry, I copied the wrong line, but the output I posted was from dataset_valid.
Also, it seems that for the predictions you used batch_size=5000, which leads to 5000 predictions, but you are using the labels of the entire dataset, which has 10000 samples.
So, if I’ve understood correctly, random_split splits the input data but not the labels, right?
In that case, what should I do?
Sorry to bother you with possibly dumb questions, I’m new to PyTorch.
No, random_split is doing fine. The problem is the way you obtained the predictions. Can you share the code that generated those predictions?
Here you need to pass only the targets that belong to the current batch, so that they have the same size as the predictions.
For instance, here is a simple snippet:
for data, target in dataloader_valid:
    # this is one batch in every iteration of the loop
    predictions = model(data)  # size [batch_size, ...]
    # now targets are also size [batch_size, ...]
But you are not using a batch for the labels; you are passing the entire set of labels outside the loop.
This is not a batch of data, it is your whole dataset.
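To make the batch-wise comparison concrete, here is a minimal sketch. It assumes a small random TensorDataset and a Linear layer as hypothetical stand-ins for MNIST and the actual network, so the names full_dataset and model are placeholders, not the code from this thread.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# Hypothetical stand-in for MNIST: 100 samples, 20 features, 10 classes
full_dataset = TensorDataset(torch.randn(100, 20), torch.randint(0, 10, (100,)))
dataset_train, dataset_valid = random_split(full_dataset, [50, 50])
dataloader_valid = DataLoader(dataset_valid, batch_size=16)

model = torch.nn.Linear(20, 10)  # placeholder for the real network
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for data, target in dataloader_valid:
        # predicted class index for each sample in this batch, size [batch_size]
        predictions = torch.argmax(model(data), dim=1)
        # target comes from the same batch, so the sizes always match
        correct += (predictions == target).sum().item()
        total += target.size(0)

accuracy = correct / total
```

Because the targets are taken from the dataloader inside the loop, they always line up with the predictions, and total ends up equal to the size of the validation split rather than the size of the original dataset.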
Here is where I defined predictions (it is a multi-class classifier):
def test(dataset, dataloader):
    # switch to test mode
    net.eval()
    # initialize predictions
    predictions = torch.zeros(len(dataset), dtype=torch.int64)
    sample_counter = 0
    # do not accumulate gradients (faster)
    with torch.no_grad():
        # test all batches
        for batch in dataloader:
            # get data from dataloader [ignore labels/targets as they are not used in test mode]
            inputs = batch[0]
            # move data to device
            inputs = inputs.to(device, non_blocking=True)
            # forward pass
            outputs = net(inputs)
            # store predictions
            outputs_max = torch.argmax(outputs, dim=1)
            for output in outputs_max:
                predictions[sample_counter] = output
                sample_counter += 1
    return predictions
I thought it was a problem with random_split because this problem has only arisen since I started using that function.
Previously I was using the original MNIST test set with the same code and had no issues.
Thanks for helping.
You are assigning the label of each sample by looping over every prediction in every batch. This is not optimal: even though you are using batches, at each iteration you copy the results one element at a time into predictions so that it ends up holding all predicted labels. Still, it should work, and I cannot see anything here that would cause a size mismatch. Maybe there is a mistake in the arguments passed to the methods?
This will do the same job:
dataloader_valid = torch.utils.data.DataLoader(dataset_valid,
                                               batch_size=len(dataset_valid),
                                               num_workers=0,
                                               pin_memory=True)

def test(dataset, dataloader):
    # switch to test mode
    net.eval()
    # do not accumulate gradients (faster)
    with torch.no_grad():
        # test all batches (only one here, since batch_size=len(dataset_valid))
        for batch in dataloader:
            # get data from dataloader [ignore labels/targets as they are not used in test mode]
            inputs = batch[0]
            # move data to device
            inputs = inputs.to(device, non_blocking=True)
            # forward pass
            outputs = net(inputs)
            # store predictions
            predictions = torch.argmax(outputs, dim=1)
    return predictions
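If the whole split does not fit into a single batch, one possible variant is to keep a smaller batch_size and concatenate the per-batch results. This is only a sketch: predict_all is a hypothetical helper name, and the random TensorDataset and Linear layer stand in for MNIST and the real network.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def predict_all(net, dataloader, device="cpu"):
    """Return the predicted class index for every sample the dataloader yields."""
    net.eval()
    batch_predictions = []
    with torch.no_grad():
        for batch in dataloader:
            # labels in batch[1] are ignored, as in test mode
            inputs = batch[0].to(device)
            outputs = net(inputs)
            # one predicted label per sample in this batch
            batch_predictions.append(torch.argmax(outputs, dim=1).cpu())
    # join the per-batch tensors into one tensor covering the whole split
    return torch.cat(batch_predictions)


torch.manual_seed(0)
# Hypothetical stand-in data: 50 samples, 20 features, 10 classes
dataset = TensorDataset(torch.randn(50, 20), torch.randint(0, 10, (50,)))
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)
net = torch.nn.Linear(20, 10)  # placeholder for the real network

predictions = predict_all(net, dataloader)
```

The result has one entry per sample in the dataloader, in the order the dataloader yields them, so keep shuffle=False if you want to compare it against labels collected separately.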
Oh, I didn’t know it, thank you.
I already posted everything involved in the line where I get the error.
The problem is that dataset_valid.dataset.targets has a length of 10000, but I can’t understand why, since I used random_split with a length of 5000.
To explain further: when you use random_split, it does not create two objects that each hold their own copy of the data. It just creates lists of indices drawn at random from the original dataset.
So, for instance, you can obtain the original dataset by calling dataset_valid.dataset, which is the same object you created in the first line of the code in your first post. The subset also has another attribute, dataset_valid.indices, which contains the 5000 indices (as you specified) of the randomly selected items in the original dataset.
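Here is a small sketch of what that looks like in practice. ToyDataset is a hypothetical stand-in that exposes data and targets attributes the way MNIST does, and the sizes (100 split into 50 and 50) are chosen only to keep the example fast.

```python
import torch
from torch.utils.data import Dataset, random_split


class ToyDataset(Dataset):
    """Hypothetical stand-in for MNIST exposing .data and .targets."""

    def __init__(self, n):
        self.data = torch.randn(n, 28, 28)
        self.targets = torch.randint(0, 10, (n,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]


full_dataset = ToyDataset(100)
dataset_train, dataset_valid = random_split(full_dataset, [50, 50])

print(len(dataset_valid))                  # 50: the subset itself
print(len(dataset_valid.dataset.targets))  # 100: targets of the original dataset

# Select only the labels that belong to the validation split
valid_indices = torch.as_tensor(dataset_valid.indices)
valid_targets = dataset_valid.dataset.targets[valid_indices]
```

Here valid_targets has the same length as dataset_valid and is in the same order as dataset_valid[0], dataset_valid[1], and so on, so it can be compared against predictions produced by a dataloader with shuffle=False.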