Weights in BCEWithLogitsLoss, but with 'weight' instead of 'pos_weight'

I’m looking into how to do class weighting with BCEWithLogitsLoss.

https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html

The example on how to use pos_weight seems clear to me. If there are 3x more negative samples than positive samples, then you can set pos_weight=3.

Does the weight parameter do the same thing?

Say that I set weight=torch.tensor([1, 3]). Is that the same thing as pos_weight=3?

Also, is weight normalized? Is weight=torch.tensor([1, 3]) the same as weight=torch.tensor([3, 9]), or are they different in how they affect the magnitude of the loss?

No, the weight argument will apply the weight to each sample of the batch.
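
To make the difference concrete, here is a minimal sketch comparing the two arguments. The logits, targets, and weight values are made up for illustration, and the per-sample weight shape of [4, 1] is just one way to line the weight up with the batch dimension. It also touches on the normalization question: as far as I can tell, with the default reduction="mean" the weighted losses are averaged over the number of elements rather than divided by the sum of the weights, so scaling the weights scales the loss.

```python
import torch
import torch.nn.functional as F

# Made-up logits and targets: 4 samples, 2 binary labels per sample.
logits = torch.tensor([[0.5, -1.0], [1.5, 0.3], [-0.7, 2.0], [0.1, -0.4]])
targets = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]])

# Unweighted per-element losses, kept for comparison.
base = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")

# weight: one value per sample here (shape [4, 1], an illustrative choice).
# It multiplies every loss element it broadcasts against, positive or negative.
w = torch.tensor([[1.0], [3.0], [1.0], [3.0]])
loss_w = torch.nn.BCEWithLogitsLoss(weight=w, reduction="none")(logits, targets)

# pos_weight: one value per label (shape [2]); only elements with target == 1 are scaled.
pw = torch.tensor([1.0, 3.0])
loss_pw = torch.nn.BCEWithLogitsLoss(pos_weight=pw, reduction="none")(logits, targets)

# With reduction="mean", the weighted losses are averaged over the element count,
# so tripling the weights triples the loss.
mean_1x = torch.nn.BCEWithLogitsLoss(weight=w)(logits, targets)
mean_3x = torch.nn.BCEWithLogitsLoss(weight=3 * w)(logits, targets)
```

In this sketch loss_w equals base * w everywhere, while loss_pw differs from base only where the target is 1.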

Thanks ptrblack!

I have one more question, about the example from the docs:

target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5) # A prediction (logit)
pos_weight = torch.ones([64]) # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target) # -log(sigmoid(1.5))
tensor(0.2014)

This seems like it could be one of two scenarios. Either it is a single binary class, with 64 examples in each of 10 batches, or it is 64 binary classes per example, with a batch size of 10.

How would torch.nn.BCEWithLogitsLoss know which scenario this is? If I put in a single value, pos_weight=torch.tensor(3), would it automatically assume it’s the former from the length of the tensor? And if I use pos_weight = torch.ones([64]*3), will BCEWithLogitsLoss assume it’s the latter?

pos_weight = torch.ones([64]*3) won’t work in this example and will raise:

RuntimeError: The size of tensor a (64) must match the size of tensor b (10) at non-singleton dimension 1

pos_weight=torch.tensor(3) will work, since the scalar pos_weight is broadcast over all elements. With pos_weight=torch.tensor(1.) you will get the same loss value as in the example.

The example uses 10 samples (batch_size=10) where each sample contains 64 classes which can be active or inactive and can thus be seen as a multi-label classification use case.
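
As a rough check of the broadcasting behavior described above, the sketch below reuses the tensors from the docs example and compares a per-class pos_weight with scalar ones. The variable names are my own. It also suggests why the loss does not need to "know" which scenario applies: the loss is computed element by element and then averaged, and pos_weight only has to broadcast against the target's shape.

```python
import torch

# Same tensors as the docs example: batch size 10, 64 labels per sample.
target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)

# Hand-computed loss for a single positive element with logit 1.5.
manual = -torch.log(torch.sigmoid(torch.tensor(1.5)))

# Per-class pos_weight of shape [64], all ones (as in the docs example).
loss_vec = torch.nn.BCEWithLogitsLoss(pos_weight=torch.ones([64]))(output, target)

# Scalar pos_weight values are broadcast over every element.
loss_one = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.0))(output, target)
loss_three = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0))(output, target)

print(manual, loss_vec, loss_one, loss_three)
```

Because every target in this example is 1, a scalar pos_weight of 3 simply triples the loss, while a scalar of 1 reproduces the original value.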
