Hi, I have implemented a Variational AutoEncoder for Collaborative Filtering (user-item). My data is very sparse: some items are very popular and others are not. So I've calculated per-item weights and am trying to use them in NLLLoss with `reduction='mean'` (it's effectively cross-entropy, because I apply LogSoftmax right before it). However, PyTorch's NLLLoss doesn't support multi-label classification.
For example, I have labels in the format `[[0, 1, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0]]`, which can be interpreted as `[1, 0, 3]`, and that works with `nn.NLLLoss`. But I can also have labels like `[[0, 1, 0, 1, 0]]`, meaning the user bought 2 out of 5 items, and I need to predict both for this sample.
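To make the two label formats concrete, here is a small sketch using `torch.nn.functional.one_hot`, assuming C = 5 classes as in the example. The class-index form only exists when each row has exactly one positive entry:

```python
import torch
import torch.nn.functional as F

# Single-label targets as class indices (the format nn.NLLLoss expects).
targets = torch.tensor([1, 0, 3])

# Equivalent one-hot encoding with C = 5 classes.
targets_one_hot = F.one_hot(targets, num_classes=5)

# Going back: argmax over the class dimension recovers the indices,
# but only because each row has exactly one positive entry.
recovered = targets_one_hot.argmax(dim=-1)

# A multi-hot (multi-label) row: items 1 and 3 are both positive, so it
# has no single class index and cannot be passed to nn.NLLLoss directly.
multi_hot = torch.tensor([[0, 1, 0, 1, 0]])
num_positives = multi_hot.sum(dim=-1)
```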
So I implemented the loss myself as `loss = -outputs * targets * weight`, which is the unreduced (i.e. with `reduction` set to `'none'`) loss, and after summing over the class dimension it matches `nn.NLLLoss` with `reduction='none'`.
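A quick check of that claim, as a sketch assuming the inputs are log-probabilities from `log_softmax` and the weight is a 1-D tensor of shape `(C,)` (the seed and sizes are arbitrary choices for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 5
log_probs = torch.log_softmax(torch.randn(3, num_classes), dim=-1)
targets = torch.tensor([1, 0, 3])
targets_one_hot = F.one_hot(targets, num_classes).float()
weight = torch.rand(num_classes)  # positive per-class weights, shape (C,)

# Element-wise formula from the post, summed over classes: for one-hot
# rows only the target column survives, giving -w[y_n] * x[n, y_n].
own_none = (-log_probs * targets_one_hot * weight).sum(dim=-1)

# PyTorch's per-sample weighted NLL for comparison.
ref_none = F.nll_loss(log_probs, targets, weight=weight, reduction='none')
```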
But I want the `'mean'` loss, and when I use the formula from the official repo, `loss.sum() / weight.sum()`, the result is not equal to `nn.NLLLoss(weight=weight, reduction='mean')`. I also tried simply taking the mean of my loss; the results seem closer, but they are still different.
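For reference, the `nn.NLLLoss` documentation defines the per-sample loss and its `'mean'` reduction roughly as follows, where $x$ holds the log-probabilities, $y_n$ is the target class of sample $n$, $w$ is the class-weight vector, and `ignore_index` is left out for brevity. The denominator is indexed by each sample's target $y_n$; with one-hot targets $t_{n,c}$ it can be written as a sum over both samples and classes:

```latex
l_n = -\, w_{y_n}\, x_{n, y_n} = -\sum_{c=1}^{C} t_{n,c}\, w_c\, x_{n,c},
\qquad
\ell(x, y) = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} w_{y_n}}
           = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} \sum_{c=1}^{C} t_{n,c}\, w_c}
```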
My question is: what formula should I use to implement a weighted NLL loss with `reduction='mean'` so that it matches `nn.NLLLoss` for single-label samples like `[0, 0, 1, 0, 0]`? I could then use the same formula for CF samples like `[0, 1, 0, 1, 0]` (meaning items 1 and 3 are present).
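One possible generalization, sketched here as a hypothetical helper `multi_hot_nll_loss` (not a PyTorch API), treats every positive label as its own weighted NLL term. Its `'mean'` reduction therefore divides by the summed weights of all positive labels. For one-hot rows this denominator equals the $\sum_n w_{y_n}$ used by PyTorch; for multi-hot rows it is equivalent to expanding each (sample, item) pair into its own single-label row. Whether this normalization suits a VAE-CF objective is an assumption, not something PyTorch prescribes.

```python
import torch
import torch.nn.functional as F


def multi_hot_nll_loss(log_probs, targets, weight, reduction='mean'):
    """Hypothetical helper: weighted NLL for multi-hot targets.

    Assumption: every positive label of a sample is its own weighted NLL
    term, so 'mean' divides by sum over n and c of t[n, c] * w[c].
    """
    targets = targets.to(log_probs.dtype)
    per_entry = -log_probs * targets * weight  # shape (N, C)
    if reduction == 'none':
        return per_entry.sum(dim=-1)           # shape (N,)
    if reduction == 'sum':
        return per_entry.sum()
    if reduction == 'mean':
        return per_entry.sum() / (targets * weight).sum()
    raise ValueError(f"unknown reduction: {reduction!r}")
```

For one-hot targets, `multi_hot_nll_loss(log_probs, F.one_hot(idx, 5), weight)` should agree with `F.nll_loss(log_probs, idx, weight=weight)`.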
Simple reproducible test:
```python
import torch

size = (3, 5)
kwargs = dict(
    reduction='mean',
    weight=torch.randn(5),  # nll_loss expects a 1-D weight tensor of size C
)
outputs = torch.randn(*size, requires_grad=True)
targets = torch.tensor([1, 0, 3])
targets_one_hot = torch.tensor([[0, 1, 0, 0, 0],
                                [1, 0, 0, 0, 0],
                                [0, 0, 0, 1, 0]])


def nll_loss(outputs, targets: torch.Tensor, weight, reduction):
    # Element-wise weighted NLL, summed over the class dimension.
    loss = -outputs * targets * weight
    loss = loss.sum(dim=-1, keepdim=True)
    if reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    elif reduction == 'mean':
        return loss.sum() / weight.sum()


src_loss_value = torch.nn.functional.nll_loss(outputs, targets, **kwargs)
own_loss_value = nll_loss(outputs, targets_one_hot, **kwargs)
assert torch.allclose(src_loss_value, own_loss_value)
```
Any ideas would be great, thank you :).