Assign a value to specific columns of each row in PyTorch

Given a 2D PyTorch tensor, I want to update specific columns of each row (every column from a per-row start index onward) as follows:

import torch
a = torch.zeros(3,7)
index = torch.LongTensor([2, 1, 3])
for i in range(a.shape[0]):
    a[i][index[i]:] = 1

Expected output data:

tensor([[0., 0., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1.]])

How could this be vectorized using PyTorch operations?

You could just compare each index to a torch.arange tensor and let broadcasting build the mask:

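# (1, 7) column positions >= (3, 1) start indices broadcasts to a (3, 7) boolean mask
b = torch.arange(7).unsqueeze(0) >= index.unsqueeze(1)
b = b.float()  # convert the boolean mask to 0./1. values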
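
For completeness, here is a self-contained sketch of the same comparison that writes into a in place and checks the result against the loop from the question. The names cols, mask and expected are only illustrative, and a[mask] = 1 is one possible way to apply the mask, assuming you want to leave the other entries of a untouched.

```python
import torch

a = torch.zeros(3, 7)
index = torch.tensor([2, 1, 3])

# Column positions with shape (1, 7), derived from a's width instead of a hard-coded 7.
cols = torch.arange(a.shape[1]).unsqueeze(0)

# Comparing against the start indices with shape (3, 1) broadcasts to a (3, 7) boolean mask.
mask = cols >= index.unsqueeze(1)

# Assign in place: only the masked positions are overwritten.
a[mask] = 1

# Reference result from the original loop, used here only as a sanity check.
expected = torch.zeros(3, 7)
for i in range(expected.shape[0]):
    expected[i, index[i]:] = 1

assert torch.equal(a, expected)
```

If a and index live on a GPU, torch.arange would presumably need device=a.device so that both sides of the comparison are on the same device.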

Awesome, thanks for the solution!
