Looking for a parallel implementation of a piece of code

import torch
x = torch.arange(25).view((5, 5))
y = torch.tensor([[3, 4, 2, 2], [0, 4, 1, 1], [2, 4, 3, 1]])  # shape: (3, 4)
result = torch.zeros(3, 4)
for i in range(1, 4):
    current_y = y[:, i]  # shape: (3,)
    prev_y = y[:, i - 1]  # shape: (3,)
    result[:, i] = x[prev_y, current_y]
print(result)

Here is the output:

tensor([[ 0., 19., 22., 12.],
        [ 0.,  4., 21.,  6.],
        [ 0., 14., 23., 16.]])

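I think this loop could be vectorized so that all columns are computed in parallel, since each iteration only reads y[:, i - 1] and y[:, i] and writes result[:, i]. But I don't know how. Can anyone help me?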

This should work, using advanced indexing with shifted column slices:

result2 = torch.zeros(3, 4)
# Pair every column with its predecessor: row index y[r, j], column index y[r, j + 1]
result2[:, 1:] = x[y[:, :-1], y[:, 1:]]
print((result == result2).all())  # prints: tensor(True)
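For anyone finding this later, here is a minimal sketch that wraps the indexing trick above into a function that works for any (rows, cols) shape of y, plus an equivalent version that uses torch.take on a flattened index. The helper names pairwise_lookup and pairwise_lookup_take are hypothetical. The sketch assumes every value in y is a valid row and column index of x. One deliberate difference from the question: the output keeps x's dtype instead of defaulting to float32.

```python
import torch


def pairwise_lookup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Vectorized loop: out[:, i] = x[y[:, i - 1], y[:, i]] for every i >= 1.

    Column 0 has no predecessor, so it stays zero, as in the original loop.
    """
    # Hypothetical helper; the output keeps x's dtype instead of float32.
    out = torch.zeros(y.shape, dtype=x.dtype)
    prev, cur = y[:, :-1], y[:, 1:]  # both have shape (rows, cols - 1)
    # Two integer index tensors of the same shape are paired element-wise:
    # x[prev, cur][r, j] == x[prev[r, j], cur[r, j]]
    out[:, 1:] = x[prev, cur]
    return out


def pairwise_lookup_take(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Same result through a flat (row-major) index into x with torch.take."""
    out = torch.zeros(y.shape, dtype=x.dtype)
    # torch.take treats x as 1-D; element (r, c) sits at r * x.shape[1] + c.
    flat_idx = y[:, :-1] * x.shape[1] + y[:, 1:]
    out[:, 1:] = torch.take(x, flat_idx)
    return out


x = torch.arange(25).view(5, 5)
y = torch.tensor([[3, 4, 2, 2], [0, 4, 1, 1], [2, 4, 3, 1]])
print(pairwise_lookup(x, y))  # rows: [0, 19, 22, 12], [0, 4, 21, 6], [0, 14, 23, 16]
print(torch.equal(pairwise_lookup(x, y), pairwise_lookup_take(x, y)))  # True
```

The advanced-indexing version is probably the more readable of the two; the torch.take variant can be handy if you already work with linear indices.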

Thank you for your answer; it solved my problem. :smiley: