I have a tensor A with shape (M, N) and another tensor B with shape (M, P). Each row of B holds column indices into the corresponding row of A. I would like to set the values of A at those indices to 0.
For example:
In[1]: import torch
A = torch.tensor([range(1,11), range(1,11), range(1,11)])
A
Out[1]:
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
In[2]: B = torch.tensor([[1,2], [2,3], [3,5]])
B
Out[2]:
tensor([[1, 2],
        [2, 3],
        [3, 5]])
The objective is to set the elements at (0-based) indices 1 and 2 in the first row, 2 and 3 in the second row, and 3 and 5 in the third row of A to 0, i.e., to turn A into:
tensor([[ 1,  0,  0,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  0,  0,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  0,  5,  0,  7,  8,  9, 10]])
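For reference, the target above can be reproduced on the toy tensors with row-wise advanced indexing. This is only a sketch: the name `rows` is introduced here for illustration, and whether this is faster than the other approaches on a real workload is untested.

```python
import torch

A = torch.tensor([list(range(1, 11))] * 3)
B = torch.tensor([[1, 2], [2, 3], [3, 5]])

# Illustrative helper (not from the original post): a column of row
# numbers with shape (M, 1), which broadcasts against B's shape (M, P).
rows = torch.arange(A.size(0), device=A.device).unsqueeze(1)

# Pairs row i with the column indices in B[i] and zeroes those entries in place.
A[rows, B] = 0
print(A)
```

The broadcast pairs each row number with every index in the matching row of B, so no Python-level loop is involved.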
I have tried a row-by-row for loop, and I have also tried scatter_:
zeros = torch.zeros_like(A)  # same shape, dtype and device as A
A.scatter_(1, B, zeros)      # in-place; writes zeros[i][j] into A[i][B[i][j]]
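One variation worth noting is that `scatter_` also accepts a scalar value in place of a source tensor, which avoids allocating a full (M, N) tensor of zeros and sidesteps the requirement that the source match A's dtype and device. The sketch below assumes the toy tensors from the example; the device selection is an assumption added so that it also runs without a GPU.

```python
import torch

# Assumption: fall back to CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

A = torch.tensor([list(range(1, 11))] * 3, device=device)
B = torch.tensor([[1, 2], [2, 3], [3, 5]], device=device)

# Scalar overload: writes 0 into A[i][B[i][j]] for every i, j, in place.
A.scatter_(1, B, 0)
print(A)
```

Whether this is measurably faster than scattering from a zeros tensor would need to be benchmarked on the actual shapes.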
Both methods work, but both perform quite poorly. I suspect a more efficient approach exists because of an earlier mistake: I initially wrote A[:, B] = 0, which zeroes every column index that appears anywhere in B, regardless of the row. That is not the result I want, but training was drastically faster with it.
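To make the difference concrete, here is a small sketch of what A[:, B] = 0 does on the toy tensors from the example: the columns named anywhere in B (1, 2, 3 and 5) are cleared in all three rows.

```python
import torch

A = torch.tensor([list(range(1, 11))] * 3)
B = torch.tensor([[1, 2], [2, 3], [3, 5]])

# The slice selects every row, so each index in B is applied to all rows.
A[:, B] = 0
print(A)
# Every row becomes [1, 0, 0, 0, 5, 0, 7, 8, 9, 10].
```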
Is there any way to implement this more efficiently?