Convert for-loop operations into PyTorch tensor operations

Can anyone help me convert this block of code into PyTorch tensor operations that will be efficient at large scale? Or tell me how I can do it. Thanks!

import torch

outputs = torch.rand(3,5)
labels = torch.randint(0, 4, (3,))

def loss_swap(outputs, labels):
    _, predicted = outputs.max(0)
    correct = predicted.eq(labels)
    for i in range(correct.shape[0]):
        if correct[i].item() == False:
            c_index = labels[i].item()
            p_index = predicted[i].item()
            tmp1, tmp2 = outputs[i, p_index].item(), outputs[i, c_index].item()
            outputs[i, c_index] = tmp1
            outputs[i, p_index] = tmp2
    return outputs

Hi Fahad!

First, in the rows where predicted equals labels, swapping the value
with itself is a no-op, so you can simply perform the swap everywhere,
unconditionally — no mask is needed.

Second, you can use tensor indexing to swap all of the values at
once rather than swapping them row by row in a loop.

(Note, you have a couple of typos in your example code that I have
corrected, below.)

Here is an example script:

import torch
print (torch.__version__)

_ = torch.manual_seed (2021)

outputs = torch.rand (3, 5)
# labels = torch.randint(0, 4, (3,))   # labels would only be 0, 1, 2, 3
labels = torch.randint(0, 5, (3,))

def loss_swap(outputs, labels):
    # _, predicted = outputs.max(0)   # predicted would only be 0, 1, 2
    _, predicted = outputs.max (1)
    correct = predicted.eq(labels)
    for i in range(correct.shape[0]):
        if correct[i].item() == False:
            c_index = labels[i].item()
            p_index = predicted[i].item()
            tmp1, tmp2 = outputs[i, p_index].item(), outputs[i, c_index].item()
            outputs[i, c_index] = tmp1
            outputs[i, p_index] = tmp2
    return outputs

def loss_swapB(outputs, labels):
    _, predicted = outputs.max (1)
    tmp = outputs[torch.arange (len (labels)), predicted]
    outputs[torch.arange (len (labels)), predicted] = outputs[torch.arange (len (labels)), labels]
    outputs[torch.arange (len (labels)), labels] = tmp
    return outputs

out = loss_swap (outputs.clone(), labels)
outB = loss_swapB (outputs.clone(), labels)

print ('outputs =')
print (outputs)
print ('labels =')
print (labels)
print ('out =')
print (out)

print ('equal:', outB.equal (out))

outputs = torch.rand (10, 100)
labels = torch.randint(0, 100, (10,))

print ('outputs.shape =', outputs.shape, 'labels.shape =', labels.shape)
print ('equal:', loss_swapB (outputs.clone(), labels).equal (loss_swap (outputs.clone(), labels)))

And here is its output:

outputs =
tensor([[0.1304, 0.5134, 0.7426, 0.7159, 0.5705],
        [0.1653, 0.0443, 0.9628, 0.2943, 0.0992],
        [0.8096, 0.0169, 0.8222, 0.1242, 0.7489]])
labels =
tensor([3, 0, 0])
out =
tensor([[0.1304, 0.5134, 0.7159, 0.7426, 0.5705],
        [0.9628, 0.0443, 0.1653, 0.2943, 0.0992],
        [0.8222, 0.0169, 0.8096, 0.1242, 0.7489]])
equal: True
outputs.shape = torch.Size([10, 100]) labels.shape = torch.Size([10])
equal: True

Best.

K. Frank