Hello, I am using squeeze() before calculating the loss. My code is given below, and it gives the error: ValueError: Target size (torch.Size([1, 256, 383])) must be the same as input size (torch.Size([1, 1, 256, 383]))
cross_entropy() and binary_cross_entropy_with_logits() take targets (your y_label) with differing shapes and, in fact, with
different numbers of dimensions.
If input (your y_hat) has shape [nBatch, nClass, height, width], cross_entropy() expects a target with shape [nBatch, height, width] (no nClass dimension) whose values
are integer class labels that run from 0 to nClass - 1. binary_cross_entropy_with_logits(), on the other hand, expects a target of shape [nBatch, nClass, height, width] (the same shape as the input)
whose values are probabilities, that is, floats that run from 0.0 to 1.0
(and can be exactly 0.0 and 1.0 in the most common use case).
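To make the two target conventions concrete, here is a minimal sketch using small random tensors in place of a real model's output (the sizes and the names ce_target and bce_target are illustrative assumptions, not from the original code). Both losses accept the same 4-d logits, but cross_entropy() wants a 3-d tensor of integer labels while binary_cross_entropy_with_logits() wants a float tensor with exactly the input's shape:

```python
import torch
import torch.nn.functional as F

nBatch, nClass, height, width = 2, 3, 4, 5

# Stand-in for a model's raw logits: [nBatch, nClass, height, width]
y_hat = torch.randn(nBatch, nClass, height, width)

# cross_entropy(): integer class labels in [0, nClass - 1], no nClass dimension
ce_target = torch.randint(0, nClass, (nBatch, height, width))  # dtype torch.int64
ce_loss = F.cross_entropy(y_hat, ce_target)

# binary_cross_entropy_with_logits(): float probabilities, same shape as the input
bce_target = torch.randint(0, 2, (nBatch, nClass, height, width)).float()
bce_loss = F.binary_cross_entropy_with_logits(y_hat, bce_target)

print(ce_target.shape, ce_target.dtype)    # torch.Size([2, 4, 5]) torch.int64
print(bce_target.shape, bce_target.dtype)  # torch.Size([2, 3, 4, 5]) torch.float32
```

With default reduction both calls return a scalar (the mean over all elements).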
(Note, using binary_cross_entropy_with_logits() with an nClass
dimension is appropriate for the multi-label, multi-class use case. In
the more common single-label, two-class (binary) use case, you would
not use an nClass dimension at all: neither nClass = 1 nor nClass = 2.)
As a side comment, this isn’t consistent with:
In your y, axis = 1 (usually called dim) has size 3, so squeeze(axis = 1)
does nothing, and y_label will have shape [1, 3, 256, 383].
(If y had, instead, shape [1, 1, 256, 383], squeeze(axis = 1)
would, in fact, remove the axis = 1 singleton dimension.)
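A quick way to see this behavior is to check shapes directly. This sketch uses zero tensors with the shapes from the discussion (the tensor contents are irrelevant here). It also shows why a bare squeeze() is risky when nBatch = 1: it removes every singleton dimension, including the batch dimension.

```python
import torch

y = torch.zeros(1, 3, 256, 383)
print(y.squeeze(1).shape)   # torch.Size([1, 3, 256, 383]): dim 1 has size 3, nothing removed
print(y.squeeze().shape)    # torch.Size([3, 256, 383]): bare squeeze() also drops the batch dim

y1 = torch.zeros(1, 1, 256, 383)
print(y1.squeeze(1).shape)  # torch.Size([1, 256, 383]): singleton dim 1 removed
print(y1.squeeze().shape)   # torch.Size([256, 383]): both singleton dims removed
```

Passing the dimension explicitly keeps the batch dimension intact regardless of batch size.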
A comment on your approach: What I think you’re trying to do is
switch to binary_cross_entropy_with_logits when you have
a two-class (binary) problem. While, in principle, your scheme could
be made to work with some complexity and attention to detail, I
don’t think it’s worth the bother.
If you know that you have a binary problem, just build that into your
code. The output of your model (y_hat) and your target will both
have shape [nBatch, height, width] (note no nClass dimension),
and you will use binary_cross_entropy_with_logits().
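A minimal sketch of the binary setup, assuming (as the shapes in the error message suggest) that the model emits a singleton channel dimension, [nBatch, 1, height, width]. The name raw_out and the random tensors are hypothetical stand-ins. Passing that output directly reproduces the size mismatch; squeezing only dim 1 gives y_hat and y_label the same [nBatch, height, width] shape:

```python
import torch
import torch.nn.functional as F

nBatch, height, width = 1, 256, 383

# Hypothetical model output with a singleton channel dimension
raw_out = torch.randn(nBatch, 1, height, width)
# Binary targets as floats (0.0 or 1.0), no nClass dimension
y_label = torch.randint(0, 2, (nBatch, height, width)).float()

raised = False
try:
    F.binary_cross_entropy_with_logits(raw_out, y_label)
except ValueError as e:
    raised = True
    print(e)  # the "Target size ... must be the same as input size ..." error

# Remove only dim 1 so the input and target shapes match
y_hat = raw_out.squeeze(1)
loss = F.binary_cross_entropy_with_logits(y_hat, y_label)
```

If you control the model, having its final layer produce [nBatch, height, width] directly avoids the squeeze altogether.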
If you want to write code for a multi-class problem that will work
for various numbers of classes, including nClass = 2,
just implement the general multi-class solution using cross_entropy(). (Again, in this case, the output of your
model should have shape [nBatch, nClass, height, width],
and your target should have shape [nBatch, height, width].)
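The general multi-class version can be sketched as follows, with random logits standing in for a real model (the sizes are arbitrary assumptions). The same code runs unchanged for nClass = 2:

```python
import torch
import torch.nn.functional as F

nBatch, nClass, height, width = 2, 4, 8, 8

y_hat = torch.randn(nBatch, nClass, height, width)           # raw logits
y_label = torch.randint(0, nClass, (nBatch, height, width))  # int64 labels in [0, nClass - 1]

# Averaged over all nBatch * height * width pixels (default reduction='mean')
loss = F.cross_entropy(y_hat, y_label)

# Per-pixel predicted class: argmax over the nClass dimension
pred = y_hat.argmax(dim=1)  # shape [nBatch, height, width], same as y_label
```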
cross_entropy() works just fine for the nClass = 2 (binary)
case, and any minor efficiency gains you might get by switching
to binary_cross_entropy_with_logits() when nClass = 2
just aren’t worth the hassle.
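One way to see that nothing is lost in the two-class case: softmax over two logits z0, z1 equals the sigmoid of their difference, so cross_entropy() on [nBatch, 2, height, width] logits gives the same value as binary_cross_entropy_with_logits() on the single logit z1 - z0. A small numerical check with random data (seed and sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
nBatch, height, width = 2, 4, 4

logits2 = torch.randn(nBatch, 2, height, width)         # two-class logits
y_label = torch.randint(0, 2, (nBatch, height, width))  # int64 labels 0 or 1

ce = F.cross_entropy(logits2, y_label)

# softmax([z0, z1])[1] == sigmoid(z1 - z0), so one logit carries the same information
z = logits2[:, 1] - logits2[:, 0]                        # shape [nBatch, height, width]
bce = F.binary_cross_entropy_with_logits(z, y_label.float())

print(torch.allclose(ce, bce, atol=1e-6))                # True
```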
It is a BAD idea to use MSELoss for classification. Here is the
reason: MSELoss cares about how far the prediction is from the
target. But in a classification problem you generally don’t have
that notion of distance.
If a tree is 60 years old, I have done better if I predict it to be 55
years old than if I predict it to be 40 years old. Both predictions
are wrong, but the 40-years-old prediction is worse. MSELoss
takes this into account.
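The tree example in numbers: for a true age of 60, the squared error of predicting 55 is (60 - 55)² = 25, while predicting 40 costs (60 - 40)² = 400, so MSELoss penalizes the farther guess sixteen times as heavily:

```python
import torch
import torch.nn.functional as F

target = torch.tensor([60.0])
near = F.mse_loss(torch.tensor([55.0]), target).item()  # (60 - 55)**2
far = F.mse_loss(torch.tensor([40.0]), target).item()   # (60 - 40)**2
print(near, far)  # 25.0 400.0
```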
But if I classify a bird as a fish, that’s wrong, but no better or worse
than classifying the bird as a reptile. There’s no sense of distance,
no concept of a class being closer to some classes and farther
from others. (Conceptually there could be, but there is not in the
typical PyTorch classification use case.)
If you want to use MSELoss for classification, you would convert
your integer class labels to floats (so target would be a FloatTensor
of shape [nBatch], with no nClass dimension). Your model would
have a single output, so it would produce batch predictions of shape [nBatch, 1], which you would squeeze() to get shape [nBatch].
Your (floating-point) target would have values running from 0.0
to nClass - 1. If your target value were, say, 2.0, then a predicted
value of 2.0 would be perfect, and MSELoss would return 0.0. A
prediction of 1.95 or 2.05 would be quite good, while a prediction
of 0.0 (or even -1.0) or 4.0 would be worse.
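A sketch of this (not recommended) MSELoss setup, with hand-picked predictions standing in for a hypothetical single-output model. It squeezes only dim 1 so that a batch of one still keeps its batch dimension. The last two lines show the distance-based ranking discussed above: for a target of 2.0, MSELoss charges 1.0 for predicting 1.0 but 4.0 for predicting 0.0.

```python
import torch
import torch.nn.functional as F

# Integer class labels converted to floats: shape [nBatch], no nClass dimension
target = torch.tensor([2, 0, 1, 2]).float()

# Hypothetical single-output model predictions: shape [nBatch, 1]
raw_pred = torch.tensor([[2.0], [0.05], [1.0], [1.95]])
pred = raw_pred.squeeze(1)  # shape [nBatch]
loss = F.mse_loss(pred, target)

# For a target of 2.0, MSELoss treats these two wrong answers very differently
t = torch.tensor([2.0])
off_by_one = F.mse_loss(torch.tensor([1.0]), t).item()
off_by_two = F.mse_loss(torch.tensor([0.0]), t).item()
print(off_by_one, off_by_two)  # 1.0 4.0
```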
But for a classification problem, for a target of 2.0, would you
really consider a prediction of 0.0 to be worse than a prediction of 1.0, or are they both just wrong because they both predict the wrong