How to add a learnable weight for each sample when I compute the loss?

Hi everyone. I am developing a new method that learns a weight for each sample (not for each batch or each class). But I'm getting an error message, so I'm hoping to get some help here. Thanks for reading!

At first, I wanted to use the loss function PyTorch officially provides:

nn.BCEWithLogitsLoss(weight=weight_score)

so I wrote a criterion; when I want to compute the loss with a sample weight, I call its forward function:

class criterion(nn.Module):
    def __init__(self):
        super(criterion, self).__init__()

    def forward(self, true, pred, score):
        loss_fn = nn.BCEWithLogitsLoss(weight=score)
        return loss_fn(pred, true)

But I get an error:

The size of tensor a (321) must match the size of tensor b (16) at non-singleton dimension 3

I don't know how to solve it. The shapes of [true, pred, score] are as follows:

torch.Size([16, 1, 321, 321])
torch.Size([16, 1, 321, 321])
torch.Size([16])
# batch_size = 16
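For context on where the size mismatch comes from: `weight` in `nn.BCEWithLogitsLoss` is multiplied element-wise into the loss tensor, so it has to broadcast against the `[16, 1, 321, 321]` loss. A `[16]` tensor aligns with the *last* dimension (size 321) under broadcasting rules, hence the error. A minimal sketch (with random tensors standing in for the real data) showing that reshaping the per-sample weights to `[16, 1, 1, 1]` makes them line up with the batch dimension:

```python
import torch
import torch.nn as nn

batch = 16
pred = torch.randn(batch, 1, 321, 321)                     # logits
true = torch.randint(0, 2, (batch, 1, 321, 321)).float()  # binary masks
score = torch.rand(batch)                                  # one weight per sample

# Fails: a [16] weight broadcasts against the last dim (321), not the batch dim.
# loss = nn.BCEWithLogitsLoss(weight=score)(pred, true)

# Works: reshape so each weight maps to one sample in the batch.
loss_fn = nn.BCEWithLogitsLoss(weight=score.view(-1, 1, 1, 1))
loss = loss_fn(pred, true)
print(loss.item())
```

With `reduction='mean'` (the default), this is equivalent to computing the unreduced per-pixel loss, multiplying each sample's pixels by its weight, and averaging over everything.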

So the key question is: "How do I add a learnable weight for each sample when I compute the loss?"

Thanks for reading; I really want to solve this problem and continue my experiments.

I would also appreciate any better solution to this problem. Thanks again!
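If the per-sample weights should actually be *learned* (as the title asks) rather than fixed, one common pattern is to store them as an `nn.Parameter` indexed by sample id and combine them with an unreduced loss. The class below is a hypothetical sketch, not the poster's method: the name `SampleWeightedBCE`, the `sample_idx` argument, and the softplus reparameterization are all illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SampleWeightedBCE(nn.Module):
    """Sketch: one learnable weight per training sample (hypothetical design)."""
    def __init__(self, num_samples):
        super().__init__()
        # unconstrained raw weights; softplus keeps the effective weights positive
        self.raw_weights = nn.Parameter(torch.zeros(num_samples))
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, pred, true, sample_idx):
        # per-pixel loss [B, 1, H, W] -> one scalar loss per sample [B]
        per_sample = self.bce(pred, true).mean(dim=(1, 2, 3))
        w = F.softplus(self.raw_weights[sample_idx])
        return (w * per_sample).mean()

# usage with dummy data
crit = SampleWeightedBCE(num_samples=100)
pred = torch.randn(4, 1, 321, 321)
true = torch.randint(0, 2, (4, 1, 321, 321)).float()
idx = torch.tensor([0, 3, 7, 9])   # which dataset samples this batch contains
loss = crit(pred, true, idx)
loss.backward()                    # gradients also flow into raw_weights
```

Note that naively minimizing this loss would drive all weights toward zero, so in practice a learnable-weight scheme usually adds a regularizer or constraint on the weights; that part depends on the method being developed.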

Are you sure you are printing the shapes just before the `nn.BCEWithLogitsLoss` call? The `true` tensor seems to have shape [16]?

Yes, the shape of `true` is torch.Size([16, 1, 321, 321]). I'm trying to build a semantic segmentation application; the label is a grayscale image with shape 1x321x321.