Hi Cameron!
The short answer is that BCEWithLogitsLoss doesn’t really have a notion of “classes” inside of it. The weights you pass in just match up, element by element, with your predictions and targets. Whether you have one or two or more “output channels” (or no “channels” dimension at all) is irrelevant – it’s just one more dimension that BCEWithLogitsLoss processes in the same way.
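(A minimal sketch of this point, with a made-up scalar example: even a zero-dimensional prediction with no batch, class, or “channels” dimension at all is accepted.)

import torch

# a single prediction / target pair -- no batch, class, or channels dimension
loss = torch.nn.BCEWithLogitsLoss() (torch.tensor (0.7), torch.tensor (1.0))
print ('scalar loss =', loss)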
Some more detail:
(A disclaimer: I might not be entirely right about this. It’s also
conceivable that the generality of the operation of BCEWithLogitsLoss
has increased over time, so my comments might not apply to older
versions of pytorch. I am currently working with version 1.7.1.)
The documentation for BCEWithLogitsLoss is somewhat misleading. It talks about batches and classes and such, but it just has dimensions; it doesn’t, per se, operate on batch dimensions and class dimensions, and so on. The user is allowed to think in those terms, but that doesn’t affect what BCEWithLogitsLoss actually does.
Consider this example:
import torch
print (torch.__version__)
_ = torch.manual_seed (2021)
nBatch = 2
# consider a "pos_weighted" multi-label, 15-class use case:
nClass = 15
pred_class = torch.randn (nBatch, nClass)
targ_class = torch.rand (nBatch, nClass)
posw_class = torch.rand (nBatch, nClass)
loss_class = torch.nn.BCEWithLogitsLoss (pos_weight = posw_class) (pred_class, targ_class)
print ('loss_class =', loss_class)
# but is there any difference between 15 classes and 3 * 5 image pixels?
height = 3
width = 5
pred_image = pred_class.reshape (nBatch, height, width)
targ_image = targ_class.reshape (nBatch, height, width)
posw_image = posw_class.reshape (nBatch, height, width)
loss_image = torch.nn.BCEWithLogitsLoss (pos_weight = posw_image) (pred_image, targ_image)
print ('loss_image =', loss_image)
# and is the batch dimension treated specially?
loss_flat = torch.nn.BCEWithLogitsLoss (pos_weight = posw_class.flatten()) (pred_class.flatten(), targ_class.flatten())
print ('loss_flat =', loss_flat)
And here is its output:
1.7.1
loss_class = tensor(0.6958)
loss_image = tensor(0.6958)
loss_flat = tensor(0.6958)
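To underscore the point, you can reproduce loss_class by applying the documented per-element formula directly, with no reference to any batch or class structure. (This is just a sketch – it regenerates the same tensors as above by reusing the seed.)

import torch

_ = torch.manual_seed (2021)
pred_class = torch.randn (2, 15)
targ_class = torch.rand (2, 15)
posw_class = torch.rand (2, 15)

# the documented per-element loss, with pos_weight scaling only the "positive" term;
# note that log (1 - sigmoid (x)) == logsigmoid (-x)
per_element = -(posw_class * targ_class * torch.nn.functional.logsigmoid (pred_class) + (1 - targ_class) * torch.nn.functional.logsigmoid (-pred_class))

# the default 'mean' reduction just averages over all of the elements
print ('loss_manual =', per_element.mean())   # matches loss_class above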
The documentation hints at this in its “shape” section:
Shape:
Input: (N, *) where "*" means, any number of additional dimensions
Target: (N, *), same shape as the input
Output: scalar. If reduction is 'none', then (N, *), same shape as input.
But BCEWithLogitsLoss doesn’t actually differentiate between the batch dimension and the other dimensions. I think this part of the documentation should simply read:
Shape:
Input: (*) where "*" means, any number of dimensions
(Use cases with “batches” and “classes” and “channels” and “images”
could be illustrated with examples, but the documentation should state
clearly that there is no operational difference.)
(As far as I can tell, input and target must have the same shape, and pos_weight must be broadcastable over input and target. Also, weight enters a little differently into the per-element loss function, but behaves the same as pos_weight in terms of shapes and dimensions.)
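Here is a sketch of both points (again reusing the seed from above, with made-up weight vectors): pos_weight broadcasting from shape (nClass,) over a (nBatch, nClass) input, and weight multiplying the entire per-element loss rather than just the “positive” term:

import torch

_ = torch.manual_seed (2021)
pred_class = torch.randn (2, 15)
targ_class = torch.rand (2, 15)

# pos_weight only needs to be broadcastable over input / target:
# shape (15,) broadcasts over shape (2, 15)
posw_vec = torch.rand (15)
loss_bcast = torch.nn.BCEWithLogitsLoss (pos_weight = posw_vec) (pred_class, targ_class)
print ('loss_bcast =', loss_bcast)

# weight, in contrast, multiplies the whole per-element loss:
#     -weight * (targ * log (sigmoid (pred)) + (1 - targ) * log (1 - sigmoid (pred)))
w_vec = torch.rand (15)
loss_w = torch.nn.BCEWithLogitsLoss (weight = w_vec) (pred_class, targ_class)
loss_w_manual = -(w_vec * (targ_class * torch.nn.functional.logsigmoid (pred_class) + (1 - targ_class) * torch.nn.functional.logsigmoid (-pred_class))).mean()
print ('loss_w =', loss_w)                  # the two should agree
print ('loss_w_manual =', loss_w_manual)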
Best.
K. Frank