Hi Fahad!
First, in those locations where `predicted` equals `labels`, it doesn’t
matter whether you swap the values or not, so you can go ahead
and swap them.
Second, you can use tensor indexing to swap all of the values at
once rather than swapping them row by row in a loop.
(Note, you have a couple of typos in your example code that I have
corrected, below.)
Here is an example script:
import torch
# Show which PyTorch version produced the results.
print(torch.__version__)

# Seed the global RNG so the random tensors below are reproducible.
_ = torch.manual_seed(2021)

# A 3 x 5 tensor of random values; each row is later indexed by a label.
outputs = torch.rand(3, 5)

# The upper bound of randint is exclusive, so 5 is needed for a label to be
# able to address any of the five columns (a bound of 4 would only ever
# yield 0, 1, 2, 3).
labels = torch.randint(low=0, high=5, size=(3,))
def loss_swap(outputs, labels):
    """Swap, row by row, the predicted entry with the labelled entry.

    For each row of ``outputs`` the column holding the row maximum (the
    prediction) trades values with the column named by ``labels``. Rows
    whose prediction already equals the label are left untouched.

    Args:
        outputs: 2-D tensor indexed as ``outputs[row, column]``. It is
            modified in place.
        labels: 1-D integer tensor giving one column index per row.

    Returns:
        The same ``outputs`` tensor object, after the in-place swaps.
    """
    _, predicted = outputs.max(1)

    # Only rows where the prediction differs from the label need any work.
    mismatched_rows = predicted.ne(labels).nonzero().flatten().tolist()

    for row in mismatched_rows:
        p_index = predicted[row].item()
        c_index = labels[row].item()

        # Copy both elements as tensors before writing either one. Going
        # through .item() would turn them into plain Python numbers and
        # detach the swapped entries from the autograd graph.
        predicted_value = outputs[row, p_index].clone()
        labelled_value = outputs[row, c_index].clone()

        outputs[row, c_index] = predicted_value
        outputs[row, p_index] = labelled_value

    return outputs
def loss_swapB(outputs, labels):
    """Vectorised swap of each row's predicted entry with its labelled entry.

    Does for all rows at once what a per-row loop would do: the column
    holding each row's maximum trades values with the column named by
    ``labels``. Rows where the two columns coincide end up unchanged, so no
    mask is needed.

    Args:
        outputs: 2-D tensor indexed as ``outputs[row, column]``. It is
            modified in place.
        labels: 1-D integer tensor giving one column index per row.

    Returns:
        The same ``outputs`` tensor object, after the in-place swap.
    """
    _, predicted = outputs.max(1)

    # Build the row index once, on the same device as the data it indexes.
    rows = torch.arange(len(labels), device=outputs.device)

    # Indexing with tensors yields a copy, so these values survive the
    # overwrite on the next line.
    predicted_values = outputs[rows, predicted]
    outputs[rows, predicted] = outputs[rows, labels]
    outputs[rows, labels] = predicted_values
    return outputs
# Each implementation gets its own clone, because both modify their argument
# in place and the untouched tensor is still needed for printing.
out = loss_swap(outputs.clone(), labels)
outB = loss_swapB(outputs.clone(), labels)

# Header and value are separated by a newline, i.e. one per line.
print('outputs =', outputs, sep='\n')
print('labels =', labels, sep='\n')
print('out =', out, sep='\n')

# The looped and the vectorised version should agree exactly.
print('equal:', outB.equal(out))

# Repeat the comparison on a larger random problem.
outputs = torch.rand(10, 100)
labels = torch.randint(0, 100, (10,))
print('outputs.shape =', outputs.shape, 'labels.shape =', labels.shape)
print(
    'equal:',
    loss_swapB(outputs.clone(), labels).equal(loss_swap(outputs.clone(), labels)),
)
And here is its output:
outputs =
tensor([[0.1304, 0.5134, 0.7426, 0.7159, 0.5705],
[0.1653, 0.0443, 0.9628, 0.2943, 0.0992],
[0.8096, 0.0169, 0.8222, 0.1242, 0.7489]])
labels =
tensor([3, 0, 0])
out =
tensor([[0.1304, 0.5134, 0.7159, 0.7426, 0.5705],
[0.9628, 0.0443, 0.1653, 0.2943, 0.0992],
[0.8222, 0.0169, 0.8096, 0.1242, 0.7489]])
equal: True
outputs.shape = torch.Size([10, 100]) labels.shape = torch.Size([10])
equal: True
Best.
K. Frank