In the simple case, given a [2,2] tensor,
[[1,1],
[1,1]]
we can use the scatter function to replace the values at indices [0,0] and [0,1] with a new value (such as zero):
[[0,0],
[1,1]].
But given a [2,2,2] tensor,
[[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]]
how can I perform the same operation with the scatter function to get this result?
[[[0, 0],
[1, 1]],
[[0, 0],
[1, 1]]]
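For reference, the [2,2] case above can be written as a single scatter call. This is a minimal sketch, assuming a scalar zero is written into row 0: along dim=1, scatter writes the value to out[i][index[i][j]], so the index [[0, 1]] targets positions [0,0] and [0,1].

```python
import torch

x = torch.ones(2, 2)
# Along dim=1, scatter writes the value to out[i][index[i][j]],
# so index [[0, 1]] targets positions [0, 0] and [0, 1].
index = torch.tensor([[0, 1]])  # int64, shape (1, 2)
out = x.scatter(1, index, 0.)   # out-of-place; x itself is unchanged
print(out)
# tensor([[0., 0.],
#         [1., 1.]])
```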
You don't need scatter for this; you can index the tensor directly:
import torch

x = torch.ones(2, 2, 2)
idx = torch.tensor([0])  # row(s) to overwrite along dim 1
x[:, idx] = 0.           # set row 0 of every [2,2] slice to zero
print(x)
> tensor([[[0., 0.],
[1., 1.]],
[[0., 0.],
[1., 1.]]])
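If you specifically want to use scatter (for example, because the positions differ from slice to slice), one option is to take the 2D index, repeat it once per slice along dim 0, and scatter along the last dimension. This is a minimal sketch, assuming the same positions [0,0] and [0,1] are zeroed in every [2,2] slice; the names index_2d and index are illustrative only.

```python
import torch

x = torch.ones(2, 2, 2)
# 2D index from the [2,2] case: row 0, columns 0 and 1.
index_2d = torch.tensor([[0, 1]])                      # shape (1, 2)
# Add a leading dim and repeat it for each slice along dim 0.
index = index_2d.unsqueeze(0).repeat(x.size(0), 1, 1)  # shape (2, 1, 2)
# Along dim=2, scatter writes the value to out[b][i][index[b][i][j]],
# i.e. out[b][0][0] and out[b][0][1] for every slice b.
out = x.scatter(2, index, 0.)
print(out)  # same values as the indexing result above
```

Because the index has an explicit entry per slice, you can edit individual rows of index to zero out different positions in each slice, which plain x[:, idx] indexing does not allow.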