Hi Jaideep!
Yes.
I’m not sure what you mean by “sum weighted rows.”
Yes. First the individual per-sample, per-class loss values are
multiplied by the corresponding weights, then summed, and then
divided by numel().
The following script illustrates these behaviors:
# Demonstrates how BCEWithLogitsLoss combines per-element weights with the
# 'none' / 'sum' / 'mean' reduction modes: unreduced losses are multiplied
# by the weights, summed, and (for 'mean') divided by numel() — NOT by the
# sum of the weights.
import torch
torch.__version__
# fixed seed so the transcript shown below is reproducible
torch.random.manual_seed (2020)
inp = torch.randn ((3, 5))  # logits: batch of 3 samples, 5 classes
trg = torch.randn ((3, 5)).sigmoid()  # probabilistic targets in (0, 1)
# unweighted case: one loss value per sample, per class
torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg)
torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg).sum()
torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg).sum() / inp.numel()
# built-in 'sum' / 'mean' reductions match the manual sum and sum/numel() above
torch.nn.BCEWithLogitsLoss (reduction = 'sum') (inp, trg)
torch.nn.BCEWithLogitsLoss (reduction = 'mean') (inp, trg)
# weights per sample (in batch) and class
wta = torch.randn ((3, 5)).abs()  # one positive weight per element, shape (3, 5)
# manual weighting: multiply the unreduced losses, then sum, then / numel()
torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wta
(torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wta).sum()
(torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wta).sum() / inp.numel()
# the weight= argument reproduces the manual results exactly; note that
# 'mean' still divides by numel(), not by wta.sum()
torch.nn.BCEWithLogitsLoss (weight = wta, reduction = 'none') (inp, trg)
torch.nn.BCEWithLogitsLoss (weight = wta, reduction = 'sum') (inp, trg)
torch.nn.BCEWithLogitsLoss (weight = wta, reduction = 'mean') (inp, trg)
# weights just per class
wtb = torch.randn ((5,)).abs()  # shape (5,): broadcast across the batch dimension
torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wtb
(torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wtb).sum()
(torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wtb).sum() / inp.numel()
# a per-class weight tensor passed as weight= is broadcast the same way
torch.nn.BCEWithLogitsLoss (weight = wtb, reduction = 'none') (inp, trg)
torch.nn.BCEWithLogitsLoss (weight = wtb, reduction = 'sum') (inp, trg)
torch.nn.BCEWithLogitsLoss (weight = wtb, reduction = 'mean') (inp, trg)
Here is the output:
>>> import torch
>>> torch.__version__
'1.6.0'
>>> torch.random.manual_seed (2020)
<torch._C.Generator object at 0x7f72eeef96f0>
>>>
>>> inp = torch.randn ((3, 5))
>>> trg = torch.randn ((3, 5)).sigmoid()
>>> torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg)
tensor([[1.0234, 0.5303, 1.3922, 0.6390, 0.4845],
[0.6889, 0.6799, 0.6828, 0.8319, 1.5794],
[0.6843, 0.7374, 0.7945, 0.7029, 1.4041]])
>>> torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg).sum()
tensor(12.8555)
>>> torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg).sum() / inp.numel()
tensor(0.8570)
>>> torch.nn.BCEWithLogitsLoss (reduction = 'sum') (inp, trg)
tensor(12.8555)
>>> torch.nn.BCEWithLogitsLoss (reduction = 'mean') (inp, trg)
tensor(0.8570)
>>>
>>> # weights per sample (in batch) and class
>>> wta = torch.randn ((3, 5)).abs()
>>> torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wta
tensor([[1.2065, 0.2057, 0.4710, 0.8512, 0.1281],
[0.3199, 1.0705, 0.5096, 0.4725, 0.8348],
[1.3118, 0.1872, 0.9706, 0.0520, 2.5681]])
>>> (torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wta).sum()
tensor(11.1597)
>>> (torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wta).sum() / inp.numel()
tensor(0.7440)
>>> torch.nn.BCEWithLogitsLoss (weight = wta, reduction = 'none') (inp, trg)
tensor([[1.2065, 0.2057, 0.4710, 0.8512, 0.1281],
[0.3199, 1.0705, 0.5096, 0.4725, 0.8348],
[1.3118, 0.1872, 0.9706, 0.0520, 2.5681]])
>>> torch.nn.BCEWithLogitsLoss (weight = wta, reduction = 'sum') (inp, trg)
tensor(11.1597)
>>> torch.nn.BCEWithLogitsLoss (weight = wta, reduction = 'mean') (inp, trg)
tensor(0.7440)
>>>
>>> # weights just per class
>>> wtb = torch.randn ((5,)).abs()
>>> torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wtb
tensor([[0.4109, 0.1231, 0.3938, 0.9537, 0.3403],
[0.2766, 0.1579, 0.1932, 1.2416, 1.1093],
[0.2747, 0.1712, 0.2247, 1.0490, 0.9862]])
>>> (torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wtb).sum()
tensor(7.9062)
>>> (torch.nn.BCEWithLogitsLoss (reduction = 'none') (inp, trg) * wtb).sum() / inp.numel()
tensor(0.5271)
>>> torch.nn.BCEWithLogitsLoss (weight = wtb, reduction = 'none') (inp, trg)
tensor([[0.4109, 0.1231, 0.3938, 0.9537, 0.3403],
[0.2766, 0.1579, 0.1932, 1.2416, 1.1093],
[0.2747, 0.1712, 0.2247, 1.0490, 0.9862]])
>>> torch.nn.BCEWithLogitsLoss (weight = wtb, reduction = 'sum') (inp, trg)
tensor(7.9062)
>>> torch.nn.BCEWithLogitsLoss (weight = wtb, reduction = 'mean') (inp, trg)
tensor(0.5271)
Best.
K. Frank