import torch
x = torch.arange(25).view((5, 5))
y = torch.tensor([[3, 4, 2, 2], [0, 4, 1, 1], [2, 4, 3, 1]]) # shape: (3, 4)
result = torch.zeros(3, 4)
for i in range(1, 4):
    current_y = y[:, i]  # shape: (3,)
    prev_y = y[:, i - 1]  # shape: (3,)
    result[:, i] = x[prev_y, current_y]
print(result)
Here is the output when I run it:
I think this loop could run in parallel, with all columns computed at once instead of one at a time, but I don't know how. Can anyone help me?
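One possible approach, offered as a sketch rather than a definitive answer, is to drop the Python loop and use advanced indexing with two index tensors at once. `y[:, :-1]` holds every "previous" column and `y[:, 1:]` holds every "current" column, both of shape (3, 3). Indexing `x` with that pair produces the same values the loop computes, in a single operation. The variable names `result_loop` and `result_vec` are illustrative only.

```python
import torch

x = torch.arange(25).view(5, 5)
y = torch.tensor([[3, 4, 2, 2], [0, 4, 1, 1], [2, 4, 3, 1]])  # shape: (3, 4)

# Loop version from the question, kept for comparison.
result_loop = torch.zeros(3, 4)
for i in range(1, y.shape[1]):
    result_loop[:, i] = x[y[:, i - 1], y[:, i]]

# Vectorized version: index x with two (3, 3) index tensors in one call.
# Element [r, j] of the gathered tensor is x[y[r, j], y[r, j + 1]],
# which is exactly what the loop writes into result[r, j + 1].
result_vec = torch.zeros(3, 4)
result_vec[:, 1:] = x[y[:, :-1], y[:, 1:]]  # column 0 stays 0, as in the loop

print(torch.equal(result_loop, result_vec))  # True
print(result_vec)
```

As in the original loop, the first column is never written, so it remains zero. Because the integer values from `x` are copied into a float tensor, the result is a float tensor just like `result` in the question.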