Looking for a parallel (vectorized) implementation of a loop over tensor columns

import torch
x = torch.arange(25).view((5, 5))
y = torch.tensor([[3, 4, 2, 2], [0, 4, 1, 1], [2, 4, 3, 1]])  # shape: (3, 4)
result = torch.zeros(3, 4)
for i in range(1, 4):
    current_y = y[:, i]  # shape: (3,)
    prev_y = y[:, i - 1]  # shape: (3,)
    result[:, i] = x[prev_y, current_y]
print(result)
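Reading directly off the code above, the loop fills each entry as follows. Indices are 0-based. Column 0 is never written, so it keeps the zero from torch.zeros:

$$
\mathrm{result}[r, i] =
\begin{cases}
0, & i = 0,\\
x\big[\,y[r, i-1],\ y[r, i]\,\big], & i = 1, 2, 3,
\end{cases}
\qquad r = 0, 1, 2.
$$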

Here is the output:

tensor([[ 0., 19., 22., 12.],
        [ 0.,  4., 21.,  6.],
        [ 0., 14., 23., 16.]])

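Each iteration reads only two columns of y and writes only one column of result, so the iterations do not depend on each other. I think the loop could therefore run in parallel (vectorized), but I don’t know how. Can anyone help me?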
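Below is a minimal sketch of one loop-free version. It assumes the goal is only to remove the Python loop while producing exactly the same result. Each iteration reads the pair (y[:, i-1], y[:, i]), so all pairs can be formed at once. Slice y into its "previous" columns (y[:, :-1]) and its "current" columns (y[:, 1:]), then index x with both in a single advanced-indexing call. The variable loop_result is only there to check that the two versions agree.

```python
import torch

x = torch.arange(25).view(5, 5)
y = torch.tensor([[3, 4, 2, 2], [0, 4, 1, 1], [2, 4, 3, 1]])  # shape: (3, 4)

# Reference: the original loop, kept only for comparison.
loop_result = torch.zeros(3, 4)
for i in range(1, 4):
    loop_result[:, i] = x[y[:, i - 1], y[:, i]]

# Vectorized: pair every column of y with its predecessor in one step.
prev_y = y[:, :-1]  # shape: (3, 3), columns 0..2 of y
curr_y = y[:, 1:]   # shape: (3, 3), columns 1..3 of y
result = torch.zeros(3, 4)
# Advanced indexing with two same-shaped index tensors returns a (3, 3) tensor
# whose entry [r, j] is x[prev_y[r, j], curr_y[r, j]]; cast to result's float dtype.
result[:, 1:] = x[prev_y, curr_y].to(result.dtype)

print(result)
print(torch.equal(result, loop_result))  # True if both versions match
```

The slices y[:, :-1] and y[:, 1:] adapt to the number of columns. For any y of shape (N, M), the same indexing line should therefore work unchanged, as long as result is created with shape (N, M). The M - 1 Python iterations are replaced by a single indexing operation.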