Suppose I have a tensor A of shape [row, col], for example:
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]])
Now, one value in each of some rows of A will be set to 0.
I have an index tensor B specifying which column should be modified, and an ‘indicator’ tensor C indicating which rows should be modified (the rows where C has value 1). B has one entry per selected row, in order. For example, C = [[1], [0], [1]] and B = [[1], [7]] mean that column 1 of the first row and column 7 of the last row of A will be set to 0:
tensor([[1, 0, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 1]])
How should I do this using PyTorch?
Since assigning through slicing modifies the tensor in place, I would prefer to implement it using scatter.
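For reference, here is a minimal sketch of one way this might be done with the out-of-place scatter, assuming A, B and C are integer tensors shaped as in the example above. The helper names mask, idx and src are my own, and using column 0 as a dummy target for the unselected rows is an assumption of this sketch, not something required by the problem.

```python
import torch

A = torch.ones(3, 9, dtype=torch.long)
C = torch.tensor([[1], [0], [1]])   # 1 marks a row that should be modified
B = torch.tensor([[1], [7]])        # one column index per marked row, in order

mask = C.bool()
# Column index for every row: B's values for marked rows, 0 (a dummy) elsewhere.
idx = torch.zeros_like(C).masked_scatter(mask, B)
# Value to write: 0 for marked rows, the existing value at the dummy column otherwise,
# so unmarked rows are rewritten with what they already contain.
src = A.gather(1, idx) * (1 - C)
# Tensor.scatter (no trailing underscore) returns a new tensor and leaves A untouched.
result = A.scatter(1, idx, src)
print(result)
```

The unmarked rows still take part in the scatter, but they only write back their own current value, so they come out unchanged. If a different fill value than 0 is needed, src would have to be built differently (for example with torch.where on C).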