Weights in BCEWithLogitsLoss, but with 'weight' instead of 'pos_weight'

I’m looking into how to do class weighting with `BCEWithLogitsLoss`.

https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html

The example of how to use `pos_weight` seems clear to me: if there are 3x more negative samples than positive samples, you can set `pos_weight=3`.

Does the `weight` parameter do the same thing?

Say that I set `weight=torch.tensor([1, 3])`. Is that the same thing as `pos_weight=3`?

Also, is `weight` normalized? Is `weight=torch.tensor([1, 3])` the same as `weight=torch.tensor([3, 9])`, or are they different in how they affect the magnitude of the loss?

No, the `weight` argument applies its weight to the loss of each sample of the batch, while `pos_weight` only rescales the loss terms of the positive targets.
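A minimal sketch of the difference (the logits, targets, and weight values here are illustrative, not from the thread):

```python
import torch
import torch.nn.functional as F

# illustrative batch: 2 samples, 1 class
logits = torch.tensor([[0.5], [-0.5]])
target = torch.tensor([[1.0], [0.0]])

# `weight` rescales the loss of every element, positive or negative target alike
w = torch.tensor([[1.0], [3.0]])
weighted = torch.nn.BCEWithLogitsLoss(weight=w)(logits, target)
plain = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
assert torch.allclose(weighted, (plain * w).mean())

# `pos_weight` rescales only the terms where target == 1
pw = torch.tensor([3.0])
pos_weighted = torch.nn.BCEWithLogitsLoss(pos_weight=pw)(logits, target)
pos_term = -F.logsigmoid(logits)   # loss contribution when target == 1
neg_term = -F.logsigmoid(-logits)  # loss contribution when target == 0
manual = (target * pw * pos_term + (1 - target) * neg_term).mean()
assert torch.allclose(pos_weighted, manual)
```

So the two parameters answer different questions: `weight` rescales whole elements, while `pos_weight` changes the positive/negative balance inside each element's loss term.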


Thanks ptrblck!

I have one more question, about the example:

```python
target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5)                  # a prediction (logit)
pos_weight = torch.ones([64])                       # all weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(1.5))
# tensor(0.2014)
```
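As a sanity check (not part of the docs example), that reported value is just `-log(sigmoid(1.5))` computed by hand:

```python
import math

# reproduce the reported tensor(0.2014): -log(sigmoid(1.5))
sigmoid = 1.0 / (1.0 + math.exp(-1.5))
print(round(-math.log(sigmoid), 4))  # 0.2014
```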

This seems like it could be one of two scenarios: either a single binary class, with 64 examples per batch and a batch size of 10, or 64 binary classes, with one example each per batch and a batch size of 10.

How would `torch.nn.BCEWithLogitsLoss` know which scenario this is? If I put in a single value, `pos_weight=torch.tensor(3)`, would it automatically assume it’s the former from the length of the tensor? And if I use `pos_weight = torch.ones([64]*3)`, will `BCEWithLogitsLoss` assume it’s the latter?

`pos_weight = torch.ones([64]*3)` won’t work in this example and will raise:

```
RuntimeError: The size of tensor a (64) must match the size of tensor b (10) at non-singleton dimension 1
```

`pos_weight=torch.tensor(3)` will broadcast the `pos_weight`, and you will get the same loss value as in the example if you use `torch.tensor(1.)`.
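To make both points concrete, here is a quick check against the docs example above (the shapes and values come from that example):

```python
import torch

target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)

# a scalar pos_weight broadcasts over all 64 classes,
# so torch.tensor(1.) reproduces the torch.ones([64]) result
scalar = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.))(output, target)
vector = torch.nn.BCEWithLogitsLoss(pos_weight=torch.ones([64]))(output, target)
assert torch.allclose(scalar, vector)  # both are -log(sigmoid(1.5)) ≈ 0.2014

# torch.ones([64] * 3) is a 64x64x64 tensor, which cannot be broadcast
# against the [10, 64] input, hence the RuntimeError
try:
    torch.nn.BCEWithLogitsLoss(pos_weight=torch.ones([64] * 3))(output, target)
except RuntimeError as e:
    print("raised:", e)
```

Note that `[64]*3` is the Python list `[64, 64, 64]`, so `torch.ones([64]*3)` builds a 3-D tensor rather than 64 classes with weight 3 each; `torch.full([64], 3.)` would express the latter.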

The example uses 10 samples (`batch_size=10`) where each sample contains 64 classes which can be active or inactive and can thus be seen as a multi-label classification use case.
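In that multi-label view, `pos_weight` naturally has one entry per class. A sketch of a common heuristic for setting it, negatives over positives per class (the random targets here are purely illustrative):

```python
import torch

torch.manual_seed(0)
# illustrative multi-label targets: 10 samples, 64 classes, each 0 or 1
targets = (torch.rand(10, 64) > 0.7).float()

# one weight per class: (# negative samples) / (# positive samples),
# clamping the denominator to avoid dividing by zero for empty classes
n_pos = targets.sum(dim=0).clamp(min=1.0)
n_neg = targets.shape[0] - targets.sum(dim=0)
pos_weight = n_neg / n_pos  # shape [64], one entry per class

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(torch.randn(10, 64), targets)
print(loss)  # a scalar: the mean over all 10 * 64 class terms
```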
