Dynamically weighting Binary Cross Entropy with Logits based on the batch

Is it possible to pass your own weights into the loss function during each batch so that they affect how the loss is calculated? Note: I don’t want to use the same weights for every batch (which is what the weight argument does); I want to weight the loss from each output neuron dynamically, based on a rule derived from the ground-truth labels.

So, for example, a 2-neuron final layer could have loss weights [[1, 0.5], [0, 0.5]] for the first batch of 2, the second batch of 2 could have [[0, 0], [1, 1]]… and so on.
The code might look something like this in the training loop:

target = torch.randint(0, 2, (2, 2)).float()  # placeholder, will be taken from the data loader
weights = create_weights_from_target(target)  # rule-based weights computed from the labels
output = model(inputs)
criterion = torch.nn.BCEWithLogitsLoss()
loss = criterion(output, target, weights=weights)  # extra arg to specify weights for each batch

Short of reimplementing the whole loss function, is there a better way to do this?
Please let me know if you need me to clarify anything. Thank you very much!!!


If you use the functional API (torch.nn.functional.binary_cross_entropy_with_logits), you can pass a fresh weight tensor for every batch via its weight argument and avoid recreating the instance of the loss function.
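A minimal sketch of that approach, assuming the same shapes as in the question and a hard-coded weight tensor standing in for the hypothetical create_weights_from_target:

import torch
import torch.nn.functional as F

output = torch.randn(2, 2, requires_grad=True)   # logits from the model
target = torch.randint(0, 2, (2, 2)).float()     # ground-truth labels from the data loader
weights = torch.tensor([[1.0, 0.5],              # per-element weights, recomputed every batch
                        [0.0, 0.5]])

# weight rescales the loss of each element before the mean reduction
loss = F.binary_cross_entropy_with_logits(output, target, weight=weights)
loss.backward()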
Alternatively, you could pass reduction='none' to the criterion and multiply the unreduced loss manually by your weights. After the multiplication you can take the mean or normalize the loss values as you wish.
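A sketch of this second approach with the module API, reusing output, target, and weights from the snippet above:

criterion = torch.nn.BCEWithLogitsLoss(reduction='none')

loss_elements = criterion(output, target)    # unreduced, shape [batch_size, num_outputs]
loss = (loss_elements * weights).mean()      # apply the per-batch weights, then reduce

If some weights are zero, dividing the weighted sum by weights.sum() (rather than taking the plain mean) keeps the loss scale independent of how many entries were zeroed out in that batch.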
