How to implement this tricky scatter?

Suppose I have a tensor A of shape [row, col], for example:

 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1]])

Now, some values in certain rows of A need to be set to 0.

I have an index tensor B specifying which column should be modified in each selected row, and an ‘indicator’ tensor C indicating which rows should be modified (the rows where C is 1). For example, C = [[1], [0], [1]] and B = [[1], [7]] say that the first and last rows of A should be modified at columns 1 and 7 respectively, giving:

 tensor([[1, 0, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 0, 1]])

How should I do this using PyTorch?

Since slicing-based assignment modifies the tensor in place, I would prefer to implement it using scatter.
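If the main concern is only avoiding the in-place write, one option that is not scatter is the out-of-place `Tensor.index_put` (no trailing underscore), which returns a new tensor. This is a minimal sketch assuming the row indices have already been extracted from C as a 1D tensor; the names `rows`, `cols` and `result` are illustrative.

```python
import torch

A = torch.ones(3, 9, dtype=torch.long)
rows = torch.tensor([0, 2])  # selected rows (positions of the 1s in C)
cols = torch.tensor([1, 7])  # column to zero in each selected row

# Indexed assignment such as A[rows, cols] = 0 would write into A itself.
# Tensor.index_put (without the underscore) returns a new tensor instead,
# pairing rows[k] with cols[k], and leaves A unchanged.
result = A.index_put((rows, cols), torch.tensor(0))
print(result)
```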

This code should work:

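import torch

x = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1, 1, 1, 1, 1]])

# Convert the 0/1 indicator C into row indices: tensor([0, 2])
C = torch.tensor([[1], [0], [1]]).squeeze(1).nonzero().squeeze(1)
# Flatten the column indices: tensor([1, 7])
B = torch.tensor([[1], [7]]).squeeze(1)
# Paired advanced indexing zeroes x[0, 1] and x[2, 7] (in place)
x[C, B] = 0
print(x)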

Note that I kept your original shapes, which is why the squeeze operations are needed. If you can create the indices in the expected (1D) shape, you can of course remove them.
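For instance, assuming C and B are created as 1D tensors from the start, a sketch without the squeezes could look like this. `nonzero(as_tuple=True)` returns one 1D index tensor per dimension, so its first element is directly the list of selected rows.

```python
import torch

x = torch.ones(3, 9, dtype=torch.long)

C = torch.tensor([1, 0, 1])  # 1D indicator, one entry per row of x
B = torch.tensor([1, 7])     # 1D column indices, one per selected row

# For a 1D tensor, nonzero(as_tuple=True) returns a 1-tuple,
# so [0] gives the selected row indices directly: tensor([0, 2])
rows = C.nonzero(as_tuple=True)[0]
x[rows, B] = 0
print(x)
```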
Also, I’ve converted the 0/1 indicator mask C into integer row indices with nonzero(), so that C and B can be paired element-wise in the indexing.
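If scatter specifically is wanted, one possible out-of-place sketch (an alternative to the indexing approach, not a confirmed recommendation) is to copy out the selected rows, scatter the zeros along dim 1 using B in its original [[1], [7]] shape, and write the rows back with the out-of-place `Tensor.index_copy`. The names `rows`, `updated_rows` and `out` are illustrative.

```python
import torch

A = torch.ones(3, 9, dtype=torch.long)
C = torch.tensor([[1], [0], [1]])  # indicator: 1 marks a row to modify
B = torch.tensor([[1], [7]])       # one column index per selected row

# Row indices of the selected rows: tensor([0, 2])
rows = C.squeeze(1).nonzero().squeeze(1)

# A[rows] returns a copy of the selected rows (shape [2, 9]).
# scatter along dim 1 writes zeros[k, 0] into column B[k, 0] of row k.
zeros = torch.zeros(B.shape, dtype=A.dtype)
updated_rows = A[rows].scatter(1, B, zeros)

# index_copy (no underscore) returns a new tensor with the updated rows
# placed at positions `rows`; A itself is not modified.
out = A.index_copy(0, rows, updated_rows)
print(out)
```

Because the scatter runs on a [num_selected, col] copy, B needs exactly one index per selected row, which matches the shape given in the question.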
