Setting values of a tensor based on given indices of corresponding rows

I have a tensor A with shape (M, N) and another tensor B with shape (M, P) whose rows hold column indices into the corresponding rows of A. I would like to set the elements of A at those indices to 0.

For example:

In[1]: import torch
       A = torch.tensor([range(1,11), range(1,11), range(1,11)])
       A
Out[1]: 
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
In[2]: B = torch.tensor([[1,2], [2,3], [3,5]])
       B
Out[2]: 
tensor([[1, 2],
        [2, 3],
        [3, 5]])

The objective is to set the elements at indices 1 and 2 in the first row, 2 and 3 in the second row, and 3 and 5 in the third row of A to 0, i.e., to turn A into

tensor([[ 1,  0,  0,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  0,  0,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  0,  5,  0,  7,  8,  9, 10]])

I have tried a row-by-row for loop, and also scatter_:

# src must match A's dtype and device, otherwise scatter_ raises an error
zeros = torch.zeros(B.shape, dtype=A.dtype, device=A.device)
A.scatter_(1, B, zeros)  # in-place: A[i, B[i, j]] = 0
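For reference, here is a minimal sketch of both attempts on the example tensors, checking that they agree. The loop body is an assumption about what "row by row" means here; the scatter_ call uses the scalar-value overload, which avoids allocating a separate zeros tensor.

```python
import torch

# Example tensors from the question (A is int64 here).
A = torch.tensor([range(1, 11), range(1, 11), range(1, 11)])
B = torch.tensor([[1, 2], [2, 3], [3, 5]])

# Row-by-row loop (assumed form): for row i, zero the columns listed in B[i].
A_loop = A.clone()
for i in range(A_loop.size(0)):
    A_loop[i, B[i]] = 0

# scatter_ with a scalar value: writes 0 at A[i, B[i, j]] for every (i, j).
A_scatter = A.clone()
A_scatter.scatter_(1, B, 0)

print(torch.equal(A_loop, A_scatter))  # True
```

The loop launches one indexing operation per row, which is likely why it is slow for large M.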

Both methods work, but both perform poorly. I suspect a more efficient approach exists because of an earlier mistake: I initially used A[:, B] = 0. This sets every index that appears anywhere in B to 0 in every row, regardless of which row the index belongs to. Yet training ran drastically faster with A[:, B] = 0.
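To make the mistake concrete, the sketch below applies A[:, B] = 0 to the example tensors. Because the slice `:` selects every row and B is used only as a set of column indices, columns 1, 2, 3 and 5 end up zeroed in all three rows.

```python
import torch

A = torch.tensor([range(1, 11), range(1, 11), range(1, 11)])
B = torch.tensor([[1, 2], [2, 3], [3, 5]])

# A[:, B] pairs every row with every index in B, ignoring which row each
# index came from, so columns {1, 2, 3, 5} are zeroed in every row.
A_wrong = A.clone()
A_wrong[:, B] = 0
print(A_wrong)
# tensor([[ 1,  0,  0,  0,  5,  0,  7,  8,  9, 10],
#         [ 1,  0,  0,  0,  5,  0,  7,  8,  9, 10],
#         [ 1,  0,  0,  0,  5,  0,  7,  8,  9, 10]])
```

It is a single vectorized indexing call, so the goal is to keep one call but pair each index with its own row.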

Is there any way to implement this more efficiently?

You could try advanced indexing, pairing each row of B with its row index:

A = torch.tensor([range(1,11), range(1,11), range(1,11)])
B = torch.tensor([[1,2], [2,3], [3,5]])
A[torch.arange(A.size(0)).unsqueeze(1), B] = 0.
print(A)
> tensor([[ 1,  0,  0,  4,  5,  6,  7,  8,  9, 10],
          [ 1,  2,  0,  0,  5,  6,  7,  8,  9, 10],
          [ 1,  2,  3,  0,  5,  0,  7,  8,  9, 10]])
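Why this works: the row index has shape (M, 1) and B has shape (M, P), so the two broadcast to a common (M, P) shape and each column index B[i, j] is paired with row i. A small sketch that makes the broadcast index pairs visible with torch.broadcast_tensors:

```python
import torch

A = torch.tensor([range(1, 11), range(1, 11), range(1, 11)])
B = torch.tensor([[1, 2], [2, 3], [3, 5]])

rows = torch.arange(A.size(0), device=A.device).unsqueeze(1)  # shape (M, 1)
print(rows.shape, B.shape)  # torch.Size([3, 1]) torch.Size([3, 2])

# Broadcasting (M, 1) against (M, P) gives matching (M, P) index tensors:
# r[i, j] == i and c[i, j] == B[i, j], so every index stays in its own row.
r, c = torch.broadcast_tensors(rows, B)
print(r)  # tensor([[0, 0], [1, 1], [2, 2]])

# One indexing call sets A[i, B[i, j]] = 0 for all (i, j).
A[rows, B] = 0
```

Creating `rows` on A's device keeps the whole operation on the GPU when A lives there.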

but I’m unsure if you would see a speedup compared to scatter_.
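Whether indexing beats scatter_ depends on shapes and hardware, so it is worth measuring on your own sizes. A minimal sketch using torch.utils.benchmark, which synchronizes CUDA when timing GPU code (unlike plain timeit); the sizes M, N, P below are assumptions for illustration only, and no timings are claimed here.

```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical sizes; substitute the shapes from your model.
M, N, P = 1024, 4096, 32
device = "cuda" if torch.cuda.is_available() else "cpu"

A = torch.randn(M, N, device=device)
B = torch.randint(0, N, (M, P), device=device)
rows = torch.arange(M, device=device).unsqueeze(1)  # (M, 1) row index

t_scatter = benchmark.Timer(
    stmt="A.scatter_(1, B, 0.0)",
    globals={"A": A, "B": B},
)
t_index = benchmark.Timer(
    stmt="A[rows, B] = 0.0",
    globals={"A": A, "B": B, "rows": rows},
)

# Both statements are idempotent, so repeating them on the same A is safe.
m_scatter = t_scatter.timeit(100)
m_index = t_index.timeit(100)
print(f"scatter_: {m_scatter.mean * 1e6:.1f} us")
print(f"indexing: {m_index.mean * 1e6:.1f} us")
```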