Update specific columns of each row in a torch Tensor

I read the post about selecting specific columns in each row, but I'm struggling to apply it to updating a 2D tensor.

I'm trying to do the equivalent of the following NumPy code with 2D torch tensors:

import numpy as np

x = np.array([[ 0,  1,  2],
              [ 3,  4,  5],
              [ 6,  7,  8],
              [ 9, 10, 11]])

idx = np.array([1, 0, 0, 2])

update_values = np.array([111, 222, 333, 444])

# Pair row i with column idx[i] and overwrite that single element
x[np.arange(0, 4), idx] = update_values

Desired result:

x = array([[  0, 111,   2],
           [222,   4,   5],
           [333,   7,   8],
           [  9,  10, 444]])

Here’s an easy way to do it:

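import torch

x = torch.FloatTensor([[ 0,  1,  2],
                       [ 3,  4,  5],
                       [ 6,  7,  8],
                       [ 9, 10, 11]])

# Column to update in each row
idx = torch.LongTensor([1, 0, 0, 2])
# Row indices 0..3, paired element-wise with idx
j = torch.arange(x.size(0)).long()

update_values = torch.FloatTensor([111, 222, 333, 444])

# Writes update_values[i] into x[j[i], idx[i]]
x[j, idx] = update_values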
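The row index tensor matters: x[:, idx] on its own selects every row combined with every listed column (a 4×4 result), not one element per row. If you prefer an explicit method call, here is a minimal sketch of two equivalent in-place alternatives, Tensor.scatter_ along dim=1 and Tensor.index_put_ with a (row, column) tuple. The helper name update_per_row is hypothetical, and it assumes idx is an int64 tensor and the values share x's dtype.

```python
import torch

def update_per_row(x: torch.Tensor, idx: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: set x[i, idx[i]] = values[i] for every row i, in place."""
    # scatter_ along dim=1 expects index and src shaped (rows, 1)
    x.scatter_(1, idx.unsqueeze(1), values.unsqueeze(1))
    return x

x = torch.tensor([[0., 1., 2.],
                  [3., 4., 5.],
                  [6., 7., 8.],
                  [9., 10., 11.]])
idx = torch.tensor([1, 0, 0, 2])  # int64 by default
values = torch.tensor([111., 222., 333., 444.])

# Work on clones so the original x stays unchanged for comparison
a = update_per_row(x.clone(), idx, values)

# Equivalent: index_put_ pairs row indices with column indices element-wise
b = x.clone().index_put_((torch.arange(x.size(0)), idx), values)
```

Both produce the same tensor as x[j, idx] = update_values; scatter_ can be handy when the per-row indices are already shaped like the data you are writing.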

Cheers!
