# Weighted NLL loss with reduction='mean' for multi-label classification (CF RecSys)

Hi, I have implemented a Variational AutoEncoder for Collaborative Filtering (user-item). My data is very sparse: some items are very popular and others are not, so I've calculated per-item weights and am trying to use them with `reduction='mean'` in `NLLLoss` (it's effectively cross-entropy, because I apply `LogSoftmax` right before it). However, PyTorch's `NLLLoss` doesn't support multi-label classification.
For example, I have labels in this format:

```
[[0, 1, 0, 0, 0],
 [1, 0, 0, 0, 0],
 [0, 0, 0, 1, 0]]
```

which could be written as the class indices `[1, 0, 3]`, and that works with `nn.NLLLoss`. But I can also have labels like `[[0, 1, 0, 1, 0]]`, meaning the user bought 2 out of the 5 items, and I need to predict both items for this sample.
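To make the index/one-hot correspondence concrete, here is a small sketch using `torch.nn.functional.one_hot`; the 5-class size is taken from the example rows above, and nothing here is specific to the VAE model.

```python
import torch
import torch.nn.functional as F

# Class-index targets: exactly one positive item per row.
targets = torch.tensor([1, 0, 3])

# one_hot turns indices into the 0/1 matrix shown above (5 items/classes).
targets_one_hot = F.one_hot(targets, num_classes=5)
print(targets_one_hot)
# tensor([[0, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0],
#         [0, 0, 0, 1, 0]])

# The reverse only works for single-label rows: argmax recovers the index.
recovered = targets_one_hot.argmax(dim=-1)
print(recovered)  # tensor([1, 0, 3])

# A multi-hot row (items 1 and 3) has no single class index,
# so it cannot be passed to nn.NLLLoss as a target.
multi_hot = torch.tensor([[0, 1, 0, 1, 0]])
print(multi_hot.sum(dim=-1))  # tensor([2]): two positive items in one row
```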
So I implemented the loss myself:

```python
loss = -outputs * targets * weight
```

Summed over the class dimension, this is the unreduced loss (i.e. with `reduction='none'`), and it matches `nn.NLLLoss` with `reduction='none'`.
But I want the `reduction='mean'` loss, and when I use the formula from the official repo, `loss.sum() / weight.sum()`, the result is not equal to `nn.NLLLoss(weight=weight, reduction='mean')`. I also tried simply taking the mean of my loss; those results seem more correlated, but they are still different.
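For reference, the `torch.nn.NLLLoss` documentation defines the weighted `'mean'` reduction as a weighted average: the summed loss is divided by the sum of the weights of the target classes that occur in the batch, not by the sum of all class weights. Writing $x$ for the log-probabilities, $y_n$ for the target index of sample $n$, $w$ for the class weights, and $t_{n,c}$ for the one-hot targets (i.e. `targets_one_hot`):

$$
\ell_n = -w_{y_n}\, x_{n,y_n},
\qquad
\ell_{\text{mean}} = \frac{\sum_{n=1}^{N} \ell_n}{\sum_{n=1}^{N} w_{y_n}}
= \frac{\sum_{n=1}^{N}\sum_{c=1}^{C} -w_c\, t_{n,c}\, x_{n,c}}{\sum_{n=1}^{N}\sum_{c=1}^{C} w_c\, t_{n,c}}
$$

The denominator $\sum_n w_{y_n}$ equals `weight.sum()` only when every class appears exactly once in the batch. Reading the right-hand form with a multi-hot $t$ is one possible extension; that is an assumption, not something `nn.NLLLoss` itself defines.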
My question is: what formula should I use to implement NLL loss with weights and `reduction='mean'` so that it matches `nn.NLLLoss` for single-label samples like `[0, 0, 1, 0, 0]`? Then I could use the same formula for CF samples like `[0, 1, 0, 1, 0]` (meaning items 1 and 3 are present).

Simple reproducible test:

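```python
import torch

torch.manual_seed(0)
size = (3, 5)
kwargs = dict(
    reduction='mean',
    # nll_loss expects a 1-D weight of shape (C,); a (1, 5) weight raises an error,
    # and torch.rand keeps the weights positive (randn can give negative ones)
    weight=torch.rand(size[1]),
)
# Log-probabilities, as produced by LogSoftmax
outputs = torch.log_softmax(torch.randn(size), dim=-1)
targets = torch.tensor([1, 0, 3])
targets_one_hot = torch.tensor([[0, 1, 0, 0, 0],
                                [1, 0, 0, 0, 0],
                                [0, 0, 0, 1, 0]])

def nll_loss(outputs, targets: torch.Tensor, weight, reduction):
    # Element-wise weighted NLL; non-target classes are zeroed out by targets
    loss = -outputs * targets * weight
    # Sum over classes -> shape (N,), same as reduction='none'
    loss = loss.sum(dim=-1)
    if reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    elif reduction == 'mean':
        return loss.sum() / weight.sum()

src_loss_value = torch.nn.functional.nll_loss(outputs, targets, **kwargs)
own_loss_value = nll_loss(outputs, targets_one_hot, **kwargs)
# Fails: the 'mean' values differ
assert torch.allclose(src_loss_value, own_loss_value)
```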
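A minimal sketch of a weighted NLL that normalizes the `'mean'` reduction by the total weight of the target classes actually present (for one-hot rows, $\sum_n w_{y_n}$) instead of `weight.sum()`. It assumes the multi-hot extension simply divides by the total weight of all positive labels in the batch; `weighted_nll_loss` is a hypothetical helper name, and the seed, logits, and weights are random placeholders. On one-hot targets it is checked against `torch.nn.functional.nll_loss` for all three reductions:

```python
import torch
import torch.nn.functional as F


def weighted_nll_loss(outputs, targets, weight, reduction='mean'):
    """Weighted NLL for one-hot or multi-hot targets (hypothetical helper).

    outputs: (N, C) log-probabilities, e.g. from log_softmax.
    targets: (N, C) 0/1 tensor; a row may contain several 1s.
    weight:  (C,) per-class weights.
    """
    targets = targets.to(outputs.dtype)
    # Per-element weighted negative log-likelihood, zero for non-target classes.
    elementwise = -outputs * targets * weight
    if reduction == 'none':
        return elementwise.sum(dim=-1)
    if reduction == 'sum':
        return elementwise.sum()
    if reduction == 'mean':
        # Assumption for multi-hot rows: divide by the total weight of all
        # positive labels; for one-hot rows this is sum_n weight[y_n].
        return elementwise.sum() / (targets * weight).sum()
    raise ValueError(f"unknown reduction: {reduction!r}")


torch.manual_seed(0)
num_items = 5
outputs = torch.log_softmax(torch.randn(3, num_items), dim=-1)
weight = torch.rand(num_items) + 0.1  # positive placeholder weights
targets = torch.tensor([1, 0, 3])
targets_one_hot = F.one_hot(targets, num_classes=num_items)

for reduction in ('none', 'sum', 'mean'):
    ref = F.nll_loss(outputs, targets, weight=weight, reduction=reduction)
    own = weighted_nll_loss(outputs, targets_one_hot, weight, reduction)
    print(reduction, torch.allclose(ref, own))  # True for all three

# Multi-hot row: items 1 and 3 both present for the same user.
multi_hot = torch.tensor([[0, 1, 0, 1, 0]])
print(weighted_nll_loss(outputs[:1], multi_hot, weight))
```

For a multi-hot row the `'mean'` value becomes a weighted average of the per-item NLL terms. Whether that is the right normalization for a CF objective is a modeling choice, not something PyTorch prescribes.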

Any ideas would be great, thank you :).