Weighted Cross Entropy for each sample in batch

I’m working on a problem that uses cross entropy loss as a reconstruction loss, and I would like to weight the loss for each sample in the mini-batch differently. The current API for cross entropy loss only accepts a weight tensor of shape (C,). I would like to pass in a weight matrix of shape (batch_size, C) so that each sample is weighted differently.

Is there a way to do this? The only workaround I can think of is to evaluate the loss for each sample in the mini-batch and pass in a new set of weights each time.

Sorry, I don’t really understand your question.
Are you saying that for a batch of size N, you want the loss between each data point in the batch and its corresponding prediction?

I would like to provide custom weights for each sample whose loss I evaluate in a batch.

At the moment I can do:

F.cross_entropy(y_pred, y_true, weight=weight)

and this will use a single set of weights for the entire batch.
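
For example, a minimal sketch (the shapes and weight values here are made up for illustration):

import torch
import torch.nn.functional as F

y_pred = torch.randn(3, 4)                    # logits, shape (batch_size, C)
y_true = torch.tensor([0, 2, 1])              # class indices, shape (batch_size,)
weight = torch.tensor([1.0, 2.0, 1.0, 0.5])   # one weight per class, shape (C,)
loss = F.cross_entropy(y_pred, y_true, weight=weight)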

What I would like to do is pass in a separate set of weights for each sample in the batch.

A set of weights for each sample?

If you are giving a separate set of weights to each sample in a batch, then what is the point of generalizing?

I mean, you could always just set your batch size to one, but I doubt the model would perform well at inference.

I’m using a denoising autoencoder and wanted to use a weighted loss function to emphasise the loss for the corrupted components of the data. A batch size of 1 could work but the model would perform poorly.
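
For example, the per-sample weights could be derived from the corruption mask. This is only a hypothetical sketch: corruption_mask and the 1 + alpha scaling are illustrative assumptions, not part of any PyTorch API.

import torch

# hypothetical binary mask of corrupted input positions, shape (batch_size, num_features)
corruption_mask = torch.tensor([[1., 0., 0., 1.],
                                [0., 1., 0., 0.]])

alpha = 2.0  # assumed scaling factor: how much extra to emphasise corrupted samples
# weight each sample by the fraction of its components that were corrupted
batch_weights = 1.0 + alpha * corruption_mask.mean(dim=1)   # shape (batch_size,)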

The only solution I can think of is to do something like:

loss = []
for i in range(len(batch)):
    # compute this sample's weights, then evaluate its loss on its own
    l = F.cross_entropy(y_pred[i].unsqueeze(0),
                        y_true[i].unsqueeze(0),
                        weight=weights[i])
    loss.append(l)
loss = torch.stack(loss).mean()   # reduce the per-sample losses to a scalar

If each sample had its own weight, then your model wouldn’t be able to generalize properly to data that wasn’t part of the training set.

It’s like having data points on a graph that make up a straight line. You would want to get the slope (m) of the line after plotting all of the data points, rather than finding m for each and every single point.

Well this could work.
Have you tried it?

but weights[i] might give an error, because the weight argument is meant to be common to every sample in the batch

Hi. You can probably do it with reduction='none'. Something like this should work:

import torch
from torch.nn import CrossEntropyLoss

criterion = CrossEntropyLoss(reduction='none')
loss = criterion(network_out, label)   # per-sample losses, shape (batch_size,)
loss = loss * batch_weights            # batch_weights: one weight per sample
loss = torch.mean(loss)
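
If you actually want the full (batch_size, C) weight matrix from the original question, you could gather each sample's weight for its own target class and use it as that sample's weight. A sketch, assuming weights has shape (batch_size, C):

import torch
import torch.nn.functional as F

batch_size, C = 4, 3
y_pred = torch.randn(batch_size, C)              # logits
y_true = torch.randint(0, C, (batch_size,))      # target class indices
weights = torch.rand(batch_size, C)              # per-sample, per-class weights

per_sample = F.cross_entropy(y_pred, y_true, reduction='none')  # shape (batch_size,)
w = weights[torch.arange(batch_size), y_true]    # each sample's weight for its target class
loss = (per_sample * w).mean()

Note that the built-in weight argument with reduction='mean' normalizes by the sum of the weights, so to match that behaviour you would divide by w.sum() instead of taking the mean.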

That seems like a much more elegant solution!

Did you see improvements doing this? I don’t… I made this thread: Clip/limit the loss for outlier samples in a batch