Question about the usage of tensors used as indices

Here is a simplified snippet of code:

import torch
import torch.nn.functional as F

pred = torch.Tensor([[-4.5665, -4.6113, -4.6146, -4.5443, -4.6290, -4.5983],
        [-4.5546, -4.5603, -4.5680, -4.6010, -4.6175, -4.6058],
        [-4.5971, -4.5846, -4.5514, -4.5913, -4.6732, -4.5830],
        [-4.6047, -4.5135, -4.5392,  -4.5392, -4.6226, -4.5826],
        [-4.5977, -4.5363, -4.6116, -4.5131, -4.5517, -4.6280],
        [-4.6122, -4.5613, -4.5750,  -4.5170, -4.6081, -4.6125]])  # size:6×6
score = torch.Tensor([0, 0.6736, 0, 0, 0.5477, 0])  # size: 6, one target per row
loss = torch.Tensor([[1.0942e-06, 9.5768e-07, 9.4854e-07,  1.1692e-06, 9.0859e-07, 9.9546e-07],
        [1.1336e-06, 1.1148e-06, 1.0895e-06, 9.8755e-07, 9.4024e-07, 9.7346e-07],
        [9.9918e-07, 1.0370e-06, 1.1447e-06,  1.0165e-06, 7.9676e-07, 1.0419e-06],
        [9.7682e-07, 1.2811e-06, 1.1868e-06,  1.1869e-06, 9.2611e-07, 1.0433e-06],
        [9.9742e-07, 1.1970e-06, 9.5695e-07, 1.2826e-06, 1.1437e-06, 9.1127e-07],
        [9.5525e-07, 1.1114e-06, 1.0671e-06, 1.2680e-06, 9.6688e-07, 9.5447e-07]])  # size:6×6
print(loss)
pos = torch.LongTensor([1, 4])  # rows 1 and 4 (i.e. the 2nd and 5th rows) of loss
pos_label = torch.LongTensor([3, 5])  # columns 3 and 5 (i.e. the 4th and 6th columns) of loss
loss[pos, pos_label] = F.binary_cross_entropy_with_logits(pred[pos, pos_label], score[pos], reduction='none')
print(loss)

The snippet updates two entries of loss with freshly computed BCE values. What I want to do is check each element of pos_label: if an element equals 5, the updated value should be multiplied by 3 (in my real case, class id 5 has a small training set, so I want to increase its loss weight). Currently the updated loss is:

tensor([[1.0942e-06, 9.5768e-07, 9.4854e-07, 1.1692e-06, 9.0859e-07, 9.9546e-07],
        [1.1336e-06, 1.1148e-06, 1.0895e-06, 3.1092e+00, 9.4024e-07, 9.7346e-07],
        [9.9918e-07, 1.0370e-06, 1.1447e-06, 1.0165e-06, 7.9676e-07, 1.0419e-06],
        [9.7682e-07, 1.2811e-06, 1.1868e-06, 1.1869e-06, 9.2611e-07, 1.0433e-06],
        [9.9742e-07, 1.1970e-06, 9.5695e-07, 1.2826e-06, 1.1437e-06, 2.5445e+00],
        [9.5525e-07, 1.1114e-06, 1.0671e-06, 1.2680e-06, 9.6688e-07, 9.5447e-07]])

What i expect is :

tensor([[1.0942e-06, 9.5768e-07, 9.4854e-07, 1.1692e-06, 9.0859e-07, 9.9546e-07],
        [1.1336e-06, 1.1148e-06, 1.0895e-06, 3.1092e+00, 9.4024e-07, 9.7346e-07],
        [9.9918e-07, 1.0370e-06, 1.1447e-06, 1.0165e-06, 7.9676e-07, 1.0419e-06],
        [9.7682e-07, 1.2811e-06, 1.1868e-06, 1.1869e-06, 9.2611e-07, 1.0433e-06],
        [9.9742e-07, 1.1970e-06, 9.5695e-07, 1.2826e-06, 1.1437e-06, 7.6335e+00],
        [9.5525e-07, 1.1114e-06, 1.0671e-06, 1.2680e-06, 9.6688e-07, 9.5447e-07]])

Please use a compare tool to spot the difference. Thanks in advance!
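As a side note on why loss[pos, pos_label] touches exactly two elements: indexing with two equal-length LongTensors pairs them elementwise (advanced indexing), rather than selecting a 2×2 sub-block. A minimal sketch with a toy tensor:

```python
import torch

t = torch.arange(36, dtype=torch.float).reshape(6, 6)
rows = torch.tensor([1, 4])
cols = torch.tensor([3, 5])
# advanced indexing pairs rows[i] with cols[i]: picks t[1, 3] and t[4, 5]
print(t[rows, cols])  # tensor([ 9., 29.])
```

So the assignment in the snippet writes only to positions (1, 3) and (4, 5).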

I worked it out myself, inspired by this comment:

import torch
import torch.nn.functional as F

pred = torch.Tensor([[-4.5665, -4.6113, -4.6146, -4.5443, -4.6290, -4.5983],
        [-4.5546, -4.5603, -4.5680, -4.6010, -4.6175, -4.6058],
        [-4.5971, -4.5846, -4.5514, -4.5913, -4.6732, -4.5830],
        [-4.6047, -4.5135, -4.5392,  -4.5392, -4.6226, -4.5826],
        [-4.5977, -4.5363, -4.6116, -4.5131, -4.5517, -4.6280],
        [-4.6122, -4.5613, -4.5750,  -4.5170, -4.6081, -4.6125]])  # size:6×6
score = torch.Tensor([0, 0.6736, 0, 0, 0.5477, 0])  # size: 6, one target per row
loss = torch.Tensor([[1.0942e-06, 9.5768e-07, 9.4854e-07,  1.1692e-06, 9.0859e-07, 9.9546e-07],
        [1.1336e-06, 1.1148e-06, 1.0895e-06, 9.8755e-07, 9.4024e-07, 9.7346e-07],
        [9.9918e-07, 1.0370e-06, 1.1447e-06,  1.0165e-06, 7.9676e-07, 1.0419e-06],
        [9.7682e-07, 1.2811e-06, 1.1868e-06,  1.1869e-06, 9.2611e-07, 1.0433e-06],
        [9.9742e-07, 1.1970e-06, 9.5695e-07, 1.2826e-06, 1.1437e-06, 9.1127e-07],
        [9.5525e-07, 1.1114e-06, 1.0671e-06, 1.2680e-06, 9.6688e-07, 9.5447e-07]])  # size:6×6
print(loss)
pos = torch.LongTensor([1, 4])  # rows 1 and 4 (i.e. the 2nd and 5th rows) of loss
pos_label = torch.LongTensor([3, 5])  # columns 3 and 5 (i.e. the 4th and 6th columns) of loss
print(loss[pos, pos_label])
updated_value = F.binary_cross_entropy_with_logits(pred[pos, pos_label], score[pos], reduction='none')
print(updated_value)

boolean_mask = sum(pos_label.eq(i) for i in [0, 4, 5]).bool()  # True where pos_label is one of the up-weighted class ids
print(boolean_mask)
print(updated_value[boolean_mask])
updated_value[boolean_mask] = 3*updated_value[boolean_mask]
loss[pos, pos_label] = updated_value
print(loss[pos, pos_label])
print(loss)
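For comparison, a per-class weight lookup table avoids building a boolean mask at all: keep one weight per class id and index it with pos_label, which scales exactly the updated entries. A minimal sketch with toy values (the class ids 0, 4, 5 and the weight 3 are just illustrative):

```python
import torch
import torch.nn.functional as F

num_classes = 6
pred = torch.full((6, 6), -4.6)   # toy logits
score = torch.zeros(6)            # toy targets, one per row
loss = torch.full((6, 6), 1e-6)   # toy per-element loss
pos = torch.tensor([1, 4])
pos_label = torch.tensor([3, 5])

# one weight per class id; rare classes (here 0, 4 and 5) get weight 3
class_weights = torch.ones(num_classes)
class_weights[torch.tensor([0, 4, 5])] = 3.0

updated = F.binary_cross_entropy_with_logits(
    pred[pos, pos_label], score[pos], reduction='none')
# class_weights[pos_label] broadcasts the right weight onto each updated entry
loss[pos, pos_label] = updated * class_weights[pos_label]
```

This scales to many classes without editing the mask list, since you only maintain the class_weights tensor.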