Hi, I have implemented a Variational AutoEncoder for Collaborative Filtering (user-item). My data is very sparse, with some items very popular and others not, so I’ve calculated per-item weights and am trying to use them with reduction='mean' in NLLLoss (it’s effectively CE, because I apply LogSoftmax right before). However, PyTorch's NLLLoss doesn’t work with multi-label classification.
For example, I have labels in the format
[[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 0, 0, 1, 0]]
which can be interpreted as the class indices [1, 0, 3] and will work with nn.NLLLoss. But I can also have labels like [[0, 1, 0, 1, 0]], meaning that the user bought 2 out of 5 items, and I need to classify both for this sample.
So I implemented the loss myself as
loss = -outputs * targets * weight
which is the unreduced loss (i.e. with reduction set to 'none'), and after summing over the class dimension it matches nn.NLLLoss for reduction='none'.
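A minimal sketch of that check, assuming `outputs` are log-probabilities, `weight` is a 1-D tensor with one positive entry per class, and the targets are one-hot (the tensor values below are made up for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(3, 5), dim=-1)  # stand-in for model outputs
weight = torch.rand(5) + 0.1                               # hypothetical per-item weights
targets = torch.tensor([1, 0, 3])
targets_one_hot = F.one_hot(targets, num_classes=5).float()

# Element-wise loss, then sum over the class dimension to get one value per sample.
manual = (-log_probs * targets_one_hot * weight).sum(dim=-1)
reference = F.nll_loss(log_probs, targets, weight=weight, reduction='none')

print(torch.allclose(manual, reference))
```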
But I want the reduction='mean' loss, and when I use the formula from the official repo, loss.sum() / weight.sum(), the result is not equal to nn.NLLLoss(weight=weight, reduction='mean').
I also tried just taking the mean of my loss; the results seem closer but are still different.
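For reference, the `nn.NLLLoss` documentation defines the weighted 'mean' as the summed loss divided by the sum of the weights of each sample's target class, not by the sum of all class weights. A small sketch comparing the two denominators, assuming made-up tensors and a 1-D positive `weight`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(3, 5), dim=-1)
weight = torch.rand(5) + 0.1          # hypothetical positive class weights, shape (C,)
targets = torch.tensor([1, 0, 3])

per_sample = F.nll_loss(log_probs, targets, weight=weight, reduction='none')
reference = F.nll_loss(log_probs, targets, weight=weight, reduction='mean')

by_target_weights = per_sample.sum() / weight[targets].sum()  # sum of w[y_n] over the batch
by_all_weights = per_sample.sum() / weight.sum()              # the denominator tried above

print(torch.allclose(by_target_weights, reference))
print(torch.allclose(by_all_weights, reference))
```

With positive weights, only the first comparison should agree with the built-in loss, because the second denominator also counts the weights of classes that never appear as a target.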
My question is: what formula should I use to implement NLL loss with weights and reduction='mean' so that it matches nn.NLLLoss for simple samples like [0, 0, 1, 0, 0]? I could then use the same formula for CF samples like [0, 1, 0, 1, 0] (meaning items 1 and 3 are present).
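One possible way to carry the same normalisation over to multi-hot targets is to divide by the total weight of all active entries, `(targets * weight).sum()`, which reduces to the sum of target-class weights in the one-hot case. This is a sketch under that assumption, not something `nn.NLLLoss` provides; `multi_hot_nll_loss` is a hypothetical helper name and the tensors are made up:

```python
import torch
import torch.nn.functional as F


def multi_hot_nll_loss(log_probs, targets, weight, reduction='mean'):
    """Weighted NLL for one-hot or multi-hot float targets of shape (N, C)."""
    weighted_targets = targets * weight                 # (N, C), weight broadcasts over rows
    per_sample = -(log_probs * weighted_targets).sum(dim=-1)
    if reduction == 'none':
        return per_sample
    if reduction == 'sum':
        return per_sample.sum()
    if reduction == 'mean':
        # Normalise by the total weight of the active targets.
        return per_sample.sum() / weighted_targets.sum()
    raise ValueError(f"unknown reduction: {reduction}")


torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(3, 5), dim=-1)
weight = torch.rand(5) + 0.1  # hypothetical positive per-item weights

# One-hot case: should agree with the built-in loss.
targets = torch.tensor([1, 0, 3])
one_hot = F.one_hot(targets, num_classes=5).float()
one_hot_loss = multi_hot_nll_loss(log_probs, one_hot, weight)
print(torch.allclose(one_hot_loss, F.nll_loss(log_probs, targets, weight=weight)))

# Multi-hot case: items 1 and 3 are present for the first user.
multi_hot = torch.tensor([[0., 1., 0., 1., 0.],
                          [1., 0., 0., 0., 0.],
                          [0., 0., 0., 1., 0.]])
multi_hot_loss = multi_hot_nll_loss(log_probs, multi_hot, weight)
print(multi_hot_loss)
```

Whether this normalisation is the right one for a multi-label recommender is a design choice; dividing by the number of samples instead would weight users with many purchases more heavily.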
A simple reproducible test:
import torch

size = (3, 5)
kwargs = dict(
    reduction='mean',
    weight=torch.randn(size=(1, 5))
)
outputs = torch.randn(*size, requires_grad=True)
targets = torch.tensor([1, 0, 3])
targets_one_hot = torch.tensor([[0, 1, 0, 0, 0],
                                [1, 0, 0, 0, 0],
                                [0, 0, 0, 1, 0]])

def nll_loss(outputs, targets: torch.Tensor, weight, reduction):
    loss = -outputs * targets * weight
    loss = loss.sum(dim=-1, keepdim=True)
    if reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    elif reduction == 'mean':
        return loss.sum() / weight.sum()

src_loss_value = torch.nn.functional.nll_loss(outputs, targets, **kwargs)
own_loss_value = nll_loss(outputs, targets_one_hot, **kwargs)
assert torch.allclose(src_loss_value, own_loss_value)
Any ideas would be great, thank you :).