I do not understand the loss function

I am learning to retrain ResNet, with one more layer of 9 classes on top of the final 1000.
I think I do not understand the criterion function.

criterian = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = 0.0001, momentum = 0.9)
for epoch in range(20):
    running_loss  = 0.0
    for X, Y in training:
        optimizer.zero_grad()
        output = net(X)

        print(output)
        print(output.shape)
        print(Y)
        print(Y.shape)
        print(Y.argmax())
        loss = criterian(output, Y.argmax())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(epoch, ':', running_loss)

It seems legit to me (since I don’t know yet how to create batches, I am providing images one by one).
My output is a tensor of shape 1x9, and the target is a one-element tensor that specifies which class the image belongs to.
However, the output is:

tensor([[0.1639, 0.2125, 0.1067, 0.1212, 0.1434, 0.0613, 0.0512, 0.0225, 0.1172]],
       device='cuda:0', grad_fn=<SoftmaxBackward>)
torch.Size([1, 9])
tensor([0., 0., 1., 0., 0., 0., 0., 0., 0.], device='cuda:0')
torch.Size([9])
tensor(2, device='cuda:0')
Traceback (most recent call last):
  File "im9_train.py", line 173, in <module>
    loss = criterian(output, Y.argmax())
  File "/home/szandala/imagenet/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/szandala/imagenet/venv/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 961, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/szandala/imagenet/venv/lib/python3.8/site-packages/torch/nn/functional.py", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/szandala/imagenet/venv/lib/python3.8/site-packages/torch/nn/functional.py", line 2260, in nll_loss
    if input.size(0) != target.size(0):

What am I doing wrong?

Hi Szandala!

The short story is that, for separate reasons, you are passing both an
improper input and an improper target to your CrossEntropyLoss criterion
function.

The shape of your output, [1, 9], is appropriate. You have a batch size
of 1, with 9 predictions for a nine-class classification problem.

You say your target is a one-element tensor, but this isn’t actually true.

Your Y is a one-dimensional tensor of length nine. Y.argmax() is a
zero-dimensional tensor. It is a “scalar” that holds a single value, but
it is not a one-dimensional tensor of length one.
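To see the difference concretely, here is a minimal sketch that rebuilds the one-hot Y from your printed output (on the CPU rather than cuda:0, an assumption made only so it runs anywhere) and compares the shape of Y.argmax() with a genuine one-dimensional tensor of length one:

```python
import torch

# One-hot target from the question (class 2 of 9), here on the CPU.
Y = torch.tensor([0., 0., 1., 0., 0., 0., 0., 0., 0.])

idx = Y.argmax()
print(idx, idx.dim(), idx.shape)           # tensor(2) 0 torch.Size([])

# For comparison: a one-dimensional tensor of length one.
idx_1d = torch.tensor([2])
print(idx_1d, idx_1d.dim(), idx_1d.shape)  # tensor([2]) 1 torch.Size([1])
```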

You should feed Y.argmax().unsqueeze(dim=0) into your loss
criterion. This will have shape [1], and so will have the necessary
batch size of one.
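A minimal sketch of the corrected loss call, assuming a random [1, 9] tensor as a hypothetical stand-in for the network's raw logits (no real model is involved). CrossEntropyLoss expects an input of shape [N, C] and a class-index target of shape [N], so with N = 1 the target must have shape [1]:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()

# Hypothetical stand-in for net(X): raw logits for 1 image and 9 classes, shape [1, 9].
output = torch.randn(1, 9)

# One-hot target as in the question, shape [9] (class 2).
Y = torch.zeros(9)
Y[2] = 1.0

# argmax() returns a 0-dim tensor; unsqueeze(dim=0) gives shape [1],
# matching the batch dimension N = 1 of output.
target = Y.argmax().unsqueeze(dim=0)
print(target.shape)  # torch.Size([1])

loss = criterion(output, target)
print(loss.item())
```

Recent PyTorch releases also accept per-class probabilities as the target, in which case a float Y.unsqueeze(0) of shape [1, 9] should work as well; the class-index form above is the one that also works in older versions.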

The grad_fn=<SoftmaxBackward> in your printed output tells me that your
model has a final Softmax layer. This is incorrect for CrossEntropyLoss,
which has, in effect, Softmax already built in (note the log_softmax call
in your traceback).

Most likely, you should have the final output of your model be the output
of your last Linear layer (which should have out_features = 9).
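A sketch of that layout, under loud assumptions: a tiny hypothetical backbone (a flatten plus one Linear layer producing 1000 outputs) stands in for the pretrained ResNet so the example runs without downloading weights, and the input is a dummy 32x32 image. In a real setup you might use something like torchvision.models.resnet18() as the backbone instead. The only point being illustrated is that the last module is a Linear layer with out_features = 9 and nothing (in particular no Softmax) follows it:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained ResNet that emits 1000 ImageNet logits.
backbone = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 1000),
)

# Extra layer mapping the 1000 outputs to 9 classes.
# No nn.Softmax at the end: CrossEntropyLoss expects raw logits.
net = nn.Sequential(
    backbone,
    nn.Linear(in_features=1000, out_features=9),
)

X = torch.randn(1, 3, 32, 32)        # one dummy RGB image, batch size 1
output = net(X)
print(output.shape)                  # torch.Size([1, 9])
print(type(output.grad_fn).__name__) # an Addmm backward node, not SoftmaxBackward
```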

Good luck.

K. Frank

Thank you.
Though I am surprised I cannot use Softmax.
So should I remove Softmax now and append it when I use the network?

Hi Szandala!

As per the documentation, CrossEntropyLoss has Softmax built in.
This is for reasons of numerical stability and convenience. If you had
a Softmax at the end of your network, you would be running your
predictions through Softmax twice, which is not what you want.
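The following sketch, using made-up logits, checks two things: that F.cross_entropy matches log_softmax followed by nll_loss (the same decomposition visible in the traceback above), and what a final Softmax layer costs you. With 9 classes, even a perfect one-hot probability vector fed to CrossEntropyLoss yields a loss of ln(1 + 8/e), about 1.372, rather than 0, because the probabilities are squashed by the second Softmax:

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical raw scores (logits) for one image and 3 classes; true class is 0.
logits = torch.tensor([[2.0, 0.5, -1.0]])
target = torch.tensor([0])

# cross_entropy is log_softmax followed by nll_loss.
ce = F.cross_entropy(logits, target)
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(torch.allclose(ce, manual))  # True

# Best case with a final Softmax layer: a perfect one-hot probability vector
# for 9 classes is softmaxed again inside cross_entropy, so the loss cannot reach 0.
perfect = torch.zeros(1, 9)
perfect[0, 2] = 1.0
best = F.cross_entropy(perfect, torch.tensor([2])).item()
print(best, math.log(1 + 8 / math.e))  # both about 1.372
```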

You should remove the final Softmax from your model.

But, in general, you do not need Softmax for anything, so you normally
won’t want to append it when you are “using the network.”

Your network will emit raw-score logits as its predictions. If you pass
them through Softmax you will get probabilities. But the logits and
probabilities contain the same information, so, in general, you can work
directly with the logits and not need to convert them to probabilities.
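If you do want probabilities, for reporting for example, one option (a sketch, not the only way) is to apply softmax to the logits outside the model at inference time rather than adding a layer back in. The sketch below, with hypothetical logits, also checks the same-information claim: log(probabilities) differs from the logits only by one constant per row, the log-sum-exp of that row:

```python
import torch

# Hypothetical raw scores (logits) for one image and 3 classes.
logits = torch.tensor([[1.5, -0.3, 0.2]])

# Convert to probabilities outside the model, only when you need them.
probs = torch.softmax(logits, dim=1)
print(probs.sum(dim=1))  # sums to 1, up to float rounding

# log(probs) is just the logits shifted by one constant per row
# (the log-sum-exp of that row), so the logits carry the same information.
shift = logits - torch.log(probs)
print(shift)
print(torch.logsumexp(logits, dim=1, keepdim=True))
```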

Note that the argmax() of a set of logits will be the same as that of
the corresponding set of probabilities. So if you make a firm “class-X”
prediction by taking the argmax() of your probabilities, you can
equivalently take the argmax() of the logits.
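A quick check of that equivalence on a random, hypothetical batch of 4 images and 9 classes: because Softmax preserves the ordering of the values within each row, argmax(dim=1) picks the same class either way:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 9)            # hypothetical batch: 4 images, 9 classes
probs = torch.softmax(logits, dim=1)

pred_from_logits = logits.argmax(dim=1)  # predicted class index per image
pred_from_probs = probs.argmax(dim=1)
print(pred_from_logits)
print(torch.equal(pred_from_logits, pred_from_probs))  # True
```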

Best.

K. Frank

Now got it, thank You very much!