Given a 2D PyTorch tensor, I want to set every column from a per-row start index onward, as follows:
import torch
a = torch.zeros(3,7)
index = torch.LongTensor([2, 1, 3])
for i in range(a.shape[0]):
    # In row i, set every column from index[i] to the end to 1
    a[i, index[i]:] = 1
Expected output:
tensor([[0., 0., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1.]])
How can this be vectorized using PyTorch operations, without the Python loop?
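One possible approach, as a minimal sketch rather than a definitive answer: compare a row of column positions against index using broadcasting, which gives a boolean mask of the same shape as a, then assign through that mask. The names cols and mask are only illustrative, and the sketch assumes index holds one start column per row.

```python
import torch

a = torch.zeros(3, 7)
index = torch.tensor([2, 1, 3])

# Column positions 0..6, shape (7,)
cols = torch.arange(a.shape[1])

# Broadcast (1, 7) against (3, 1) to get a boolean mask of shape (3, 7);
# mask[i, j] is True where column j is at or past index[i]
mask = cols.unsqueeze(0) >= index.unsqueeze(1)

# Assign 1 to every masked position in a single operation
a[mask] = 1
print(a)
```

If a starts out as all zeros, as it does here, mask.to(a.dtype) should produce the same tensor directly; the masked assignment is the more general form because it leaves the other entries of an existing tensor untouched.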