Unbalanced multi-label classification loss

I have the following loss function:

import torch
import torch.nn as nn


class WeightedBCE(nn.Module):
    def __init__(self, pos_w, neg_w):
        super(WeightedBCE, self).__init__()
        pos_w = torch.tensor(pos_w, dtype=torch.float, requires_grad=False)
        neg_w = torch.tensor(neg_w, dtype=torch.float, requires_grad=False)
        # register as buffers so the weights follow the module (e.g. .to(device)) but are not trained
        self.register_buffer("pos_w", pos_w)
        self.register_buffer("neg_w", neg_w)
        self.eps = 1e-10

    def forward(self, y_hat, t):
        # y_hat: sigmoid probabilities, shape (batch, num_labels); t: binary targets of the same shape
        loss = 0.
        for label in range(self.pos_w.shape[0]):
            pos_loss = -1. * torch.mean(t[:, label] * self.pos_w[label] * torch.log(y_hat[:, label] + self.eps))
            neg_loss = -1. * torch.mean((1. - t[:, label]) * self.neg_w[label] * torch.log(1. - y_hat[:, label] + self.eps))
            loss += pos_loss + neg_loss

        return loss

Where neg_w is:

tensor([0.2246, 0.0539, 0.0907, 0.0448, 0.2572, 0.0486, 0.0319, 0.0044, 0.3841,
        0.1113, 0.1222, 0.0656, 0.0283, 0.1016])

and pos_w is:

tensor([0.7754, 0.9461, 0.9093, 0.9552, 0.7428, 0.9514, 0.9681, 0.9956, 0.6159,
        0.8887, 0.8778, 0.9344, 0.9717, 0.8984])

My model output is after a sigmoid layer, but for some reason I can't see any improvement in my custom Hamming score during training:

[ Epoch 1/100 ] [ Batch 1/825 ] [ Loss: 1.840901 ] [Batch Hamming Score: 0.107143] [ Batch Time: 0:00:09.478559 ]

[ Epoch 1/100 ] [ Batch 101/825 ] [ Loss: 1.701975 ] [Batch Hamming Score: 0.102679] [ Batch Time: 0:00:01.401633 ]

[ Epoch 1/100 ] [ Batch 201/825 ] [ Loss: 1.664751 ] [Batch Hamming Score: 0.102679] [ Batch Time: 0:00:01.400213 ]

[ Epoch 1/100 ] [ Batch 301/825 ] [ Loss: 1.745564 ] [Batch Hamming Score: 0.119048] [ Batch Time: 0:00:01.403180 ]

[ Epoch 1/100 ] [ Batch 401/825 ] [ Loss: 1.792013 ] [Batch Hamming Score: 0.117560] [ Batch Time: 0:00:01.408394 ]

[ Epoch 1/100 ] [ Batch 501/825 ] [ Loss: 1.696599 ] [Batch Hamming Score: 0.104167] [ Batch Time: 0:00:01.404566 ]

[ Epoch 1/100 ] [ Batch 601/825 ] [ Loss: 1.706017 ] [Batch Hamming Score: 0.108631] [ Batch Time: 0:00:01.409217 ]

[ Epoch 1/100 ] [ Batch 701/825 ] [ Loss: 1.735794 ] [Batch Hamming Score: 0.107143] [ Batch Time: 0:00:01.409518 ]

[ Epoch 1/100 ] [ Batch 801/825 ] [ Loss: 1.657647 ] [Batch Hamming Score: 0.105655] [ Batch Time: 0:00:01.402675 ]

It stays that way throughout the training.

My Hamming score function is:

import numpy as np


def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
    '''
    Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case.
    http://stackoverflow.com/q/32239577/395857
    '''
    acc_list = []
    for i in range(y_true.shape[0]):
        set_true = set(np.where(y_true[i])[0])
        set_pred = set(np.where(y_pred[i])[0])
        if len(set_true) == 0 and len(set_pred) == 0:
            # no true labels and no predicted labels counts as a perfect match
            tmp_a = 1
        else:
            # intersection over union of the true and predicted label sets
            tmp_a = len(set_true.intersection(set_pred)) / float(len(set_true.union(set_pred)))
        acc_list.append(tmp_a)
    return np.mean(acc_list)

Do you have any suggestions on how I could improve my loss?

I was thinking of using BCEWithLogitsLoss and removing the sigmoid layer from the model, but I wasn't sure what the weight vector should be in that case.

Hi David!

First, as an aside: I haven’t looked closely at your WeightedBCE code, but I
don’t see anything mathematically wrong with it. However, the for-loop over
label will slow things down in comparison with pytorch’s built-in BCELoss
(and the preferred BCEWithLogitsLoss).
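
For example, the same weighting scheme can be computed without the per-label python loop by relying on broadcasting. A rough, stand-alone sketch (the function name is just illustrative):

import torch

def weighted_bce_vectorized(y_hat, t, pos_w, neg_w, eps=1e-10):
    # y_hat, t: shape (batch, num_labels); pos_w, neg_w: shape (num_labels,)
    pos_term = t * pos_w * torch.log(y_hat + eps)
    neg_term = (1. - t) * neg_w * torch.log(1. - y_hat + eps)
    # mean over the batch dimension, then sum over labels,
    # matching the reduction used in the WeightedBCE class above
    return -(pos_term + neg_term).mean(dim=0).sum()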

Pytorch’s built-in BCEWithLogitsLoss supports positive-sample weighting
with its pos_weight constructor argument. Note that, for whatever reason,
BCELoss does not offer this pos_weight feature.

BCEWithLogitsLoss does not have a corresponding neg_weight option,
but its weight constructor argument gives overall class weights.

I haven’t double-checked the math, but I believe that for you to reproduce your
weighting scheme with BCEWithLogitsLoss, you would want to construct your
loss-function object as follows:

loss_fn = torch.nn.BCEWithLogitsLoss (weight = neg_w, pos_weight = pos_w / neg_w)
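
A quick way to sanity-check this claim numerically: with some made-up data, your WeightedBCE and the built-in loss should agree up to a factor of num_labels, because your loss sums the per-label means while BCEWithLogitsLoss's default "mean" reduction averages over all elements. Something like this sketch (random stand-in weights and targets, and your WeightedBCE class from above):

import torch

torch.manual_seed(0)
num_labels = 14
pos_w = torch.rand(num_labels) * 0.5 + 0.5   # made-up weights, just for the check
neg_w = 1. - pos_w

logits = torch.randn(8, num_labels)
targets = (torch.rand(8, num_labels) > 0.5).float()

custom = WeightedBCE(pos_w.tolist(), neg_w.tolist())
loss_custom = custom(torch.sigmoid(logits), targets)

builtin = torch.nn.BCEWithLogitsLoss(weight=neg_w, pos_weight=pos_w / neg_w)
loss_builtin = builtin(logits, targets)

print(loss_custom.item(), (num_labels * loss_builtin).item())   # nearly equal, up to the eps term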

It does look like your loss may be drifting down, although I agree that I do
not see any improvement in your “Hamming Score.” However, you’ve only
trained for a single epoch, so I wouldn’t be too hasty about drawing any
conclusions. Also, some batches can be harder or easier than others (and
differently so for your loss vs. other figures of merit), so batch-to-batch
variance may mask a systematic trend.

You should certainly experiment with learning rate and momentum, and
possibly try other optimizers such as Adam if you are currently using SGD.
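
For concreteness, a generic sketch with a stand-in model (the specific values are just starting points to sweep over, not recommendations):

import torch

model = torch.nn.Linear(512, 14)   # stand-in for your actual network

# if staying with SGD, sweep the learning rate and momentum
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

# or try Adam, which is often less sensitive to the initial learning rate
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)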

You should certainly prefer BCEWithLogitsLoss over BCELoss with sigmoid
for numerical-stability reasons. However, unless you are hitting issues with
numerical stability, the two will be the same, so switching won’t improve the
results of your training. (But you should switch to BCEWithLogitsLoss anyway
to guard against future stability issues.)
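
Concretely, the switch would look something like this sketch (with a stand-in model and made-up weights): the network outputs raw logits, BCEWithLogitsLoss consumes the logits directly, and the sigmoid is applied only where you need probabilities, e.g. for your Hamming score:

import torch

num_labels = 14
model = torch.nn.Linear(512, num_labels)   # stand-in: your network with the final sigmoid removed

pos_w = torch.rand(num_labels) * 0.5 + 0.5   # made-up weights, just for illustration
neg_w = 1. - pos_w
loss_fn = torch.nn.BCEWithLogitsLoss(weight=neg_w, pos_weight=pos_w / neg_w)

x = torch.randn(8, 512)
t = (torch.rand(8, num_labels) > 0.5).float()

logits = model(x)                 # raw scores -- no sigmoid inside the model
loss = loss_fn(logits, t)         # the loss applies the (log-)sigmoid internally, stably
probs = torch.sigmoid(logits)     # probabilities only where needed (metrics, inference)
preds = (probs > 0.5).float()     # hard multi-label predictions for the Hamming score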

Best.

K. Frank


Just another question: in the docs for BCEWithLogitsLoss, the weight argument is described as a per-batch-element weight.
I wasn't sure whether to even use it. Did you have some other source for it?

Hi David!

I’m not sure how I concluded that BCEWithLogitsLoss's weight argument
could be used for class weights – maybe from experimenting. Anyway, I
agree with you that the documentation is inobvious at best, and probably
wrong in some details.

It appears that weight can be used for both batch and class weights. Note the
dimensions in the following examples, which I think shed some light on what is
going on:

>>> import torch
>>> torch.__version__
'1.7.1'
>>> p = 0.7 * torch.ones ((2, 3))
>>> t = torch.ones ((2, 3))
>>> w3 = 0.5 * torch.ones ((3,))
>>> w2 = 0.5 * torch.ones ((2,))
>>> w23 = 0.5 * torch.ones ((2,3))
>>> w32 = 0.5 * torch.ones ((3,2))
>>> torch.nn.BCEWithLogitsLoss() (p, t)
tensor(0.4032)
>>> torch.nn.BCEWithLogitsLoss (weight = w3) (p, t)
tensor(0.2016)
>>> torch.nn.BCEWithLogitsLoss (weight = w23) (p, t)
tensor(0.2016)
>>> torch.nn.BCEWithLogitsLoss (weight = w2) (p, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 629, in forward
    return F.binary_cross_entropy_with_logits(input, target,
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2582, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
>>> torch.nn.BCEWithLogitsLoss (weight = w32) (p, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 629, in forward
    return F.binary_cross_entropy_with_logits(input, target,
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2582, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

Best.

K. Frank
