Changing nll_loss to mse_loss, but I get "The size of tensor a (64) must match the size of tensor b (10) at non-singleton dimension 1"

Here is my code.

```python
import torch
import torch.nn.functional as F
import torch.optim as optim

# Net, train_loader, learning_rate, momentum, n_epochs and log_interval
# are defined elsewhere in the script.

def simplest_loss(output, target):
    #loss = target-output
    loss = torch.mean((output - target)**2)
    return loss


network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        #loss = F.nll_loss(output, target)

        #loss = simplest_loss(output, target)
        # This line raises the size-mismatch error from the title
        loss = F.mse_loss(target, output)
        #loss = torch.nn.MSELoss(target, output)

        print("output")
        print(output)
        print("target")
        print(target)
        print("loss")
        print(loss)

        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
                (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
            torch.save(network.state_dict(), 'results/model.pth')
            torch.save(optimizer.state_dict(), 'results/optimizer.pth')
```
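For tensors of the same shape, the hand-written simplest_loss computes the same value as F.mse_loss with its default reduction='mean', so switching between those two options does not change the shape requirement. A quick check with random tensors (the [64, 10] shape is an assumption standing in for a batch of 64 and 10 classes):

```python
import torch
import torch.nn.functional as F


def simplest_loss(output, target):
    # Mean of the squared element-wise differences (same formula as above)
    return torch.mean((output - target) ** 2)


torch.manual_seed(0)
output = torch.randn(64, 10)  # assumed [batch, classes] shape
target = torch.randn(64, 10)  # same shape as output, so no broadcasting happens

hand_written = simplest_loss(output, target)
built_in = F.mse_loss(output, target)  # default reduction='mean'
print(hand_written.item(), built_in.item())
```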

The expected model output and target shapes differ for nn.NLLLoss and nn.MSELoss, as explained in their docs.
I don't know what your actual use case is, but the current error in nn.MSELoss is raised because the model output and the target are expected to have the same shape.
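A minimal sketch of the difference, assuming a batch-first output of [64, 10] and a target of 64 class indices (random stand-ins, not the real model). F.nll_loss takes an [N, C] input of log-probabilities and an [N] target of indices, while F.mse_loss(input, target) takes the prediction first and needs both tensors to have the same shape. Note that the question passes the arguments as (target, output). mse_shape_error is a hypothetical helper for illustration only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
output = torch.randn(64, 10)          # stand-in for network(data): [batch, classes]
target = torch.randint(0, 10, (64,))  # class indices: [batch]

# nll_loss: input [N, C] of log-probabilities, target [N] of class indices
nll = F.nll_loss(F.log_softmax(output, dim=1), target)


def mse_shape_error(a, b):
    """Hypothetical helper: return the error message if F.mse_loss rejects the shapes."""
    try:
        F.mse_loss(a, b)
    except RuntimeError as err:
        return str(err)
    return None


# mse_loss: input and target must have the same shape; the [64] index tensor
# and the [64, 10] output (argument order as in the question) cannot broadcast
msg = mse_shape_error(target, output)
print(msg)
```

PyTorch also emits a UserWarning about the differing sizes before the broadcast fails.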

It turns out that the output had shape 10 by 64, and the target needed to be transformed into a one-hot encoding, from 64 to 10 by 64.
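The one-hot fix can be sketched with F.one_hot, assuming 10 classes and a batch of 64 (random stand-in data). F.one_hot returns an integer tensor of shape [batch, num_classes], here [64, 10], so it needs .float() before mse_loss. The error message in the title reports size 10 at dimension 1, which points to a batch-first output of [64, 10] rather than [10, 64].

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_classes = 64, 10
output = torch.randn(batch_size, num_classes, requires_grad=True)  # stand-in for network(data)
target = torch.randint(0, num_classes, (batch_size,))              # class indices, shape [64]

# one_hot returns int64 of shape [64, 10]; mse_loss needs a floating-point target
target_onehot = F.one_hot(target, num_classes=num_classes).float()

loss = F.mse_loss(output, target_onehot)  # prediction first, then target
loss.backward()
print(target_onehot.shape, loss.item())
```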

I'm still not sure how NLLLoss works after reading your link, but I don't need to know at the moment; I might revisit this later.
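In short, with the default reduction='mean' and no class weights, nll_loss takes each sample's log-probability at its target class index, negates it, and averages over the batch. That is why it wants [N, C] log-probabilities (for example from log_softmax) and an [N] index target, unlike mse_loss. A small check on random data (the 4 samples and 3 classes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)                 # 4 samples, 3 classes
log_probs = F.log_softmax(logits, dim=1)   # nll_loss expects log-probabilities [N, C]
target = torch.tensor([0, 2, 1, 2])        # one class index per sample, shape [N]

built_in = F.nll_loss(log_probs, target)

# Same computation by hand: pick each sample's log-probability at its target
# class, negate it, and average over the batch (reduction='mean', no weights)
manual = -log_probs[torch.arange(target.size(0)), target].mean()
print(built_in.item(), manual.item())
```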

Thanks anyway.

Double-check these shapes, since an output of shape [10, 64] should not work with a target of shape [64] when using nn.NLLLoss.
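One way to double-check is to call F.nll_loss on both layouts with random stand-in data. It treats dimension 0 of the input as the batch, so [10, 64] is read as 10 samples with 64 classes, which disagrees with a target of 64 indices. nll_accepts is a hypothetical helper; catching both ValueError and RuntimeError is a hedge, since the exact exception type may vary across PyTorch versions.

```python
import torch
import torch.nn.functional as F


def nll_accepts(output, target):
    """Hypothetical helper: True if F.nll_loss accepts these shapes, else False."""
    try:
        F.nll_loss(output, target)
        return True
    except (ValueError, RuntimeError):
        return False


torch.manual_seed(0)
target = torch.randint(0, 10, (64,))                      # 64 class indices

batch_first = F.log_softmax(torch.randn(64, 10), dim=1)  # [N, C] = [64, 10]
transposed = F.log_softmax(torch.randn(10, 64), dim=1)   # read as N=10, C=64

print(nll_accepts(batch_first, target))  # True
print(nll_accepts(transposed, target))   # False: batch sizes 10 and 64 differ
```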