For a simple case, given a [2,2] tensor,

[[1,1],

[1,1]]

we can use the "scatter" function to replace the values at indices [0,0] and [0,1] with a new value (such as zero):

[[0,0],

[1,1]]

But given a [2,2,2] tensor

[[[1., 1.],

[1., 1.]],

[[1., 1.],

[1., 1.]]]

how can I achieve the same operation as above with the scatter function?

[[[0, 0],

[1, 1]],

[[0, 0],

[1, 1]]]

You can directly index the tensor as seen here:

```
import torch

x = torch.ones(2, 2, 2)
idx = torch.tensor([0])
# Zero out row 0 of every 2x2 matrix along the first dimension
x[:, idx] = 0.
print(x)
# tensor([[[0., 0.],
#          [1., 1.]],
#
#         [[0., 0.],
#          [1., 1.]]])
```
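If you specifically want to use scatter, as the question asks, here is a minimal sketch of one way to get the same result with `Tensor.scatter_`. The variable names (`x2`, `x3`, `index2`, `index3`) are only illustrative. It assumes you always want to overwrite row 0, so the index tensor is all zeros and the scatter runs along the row dimension:

```python
import torch

# 2D case: scatter along dim 0, so x2[index2[i][j]][j] = 0
x2 = torch.ones(2, 2)
index2 = torch.zeros(1, 2, dtype=torch.long)  # target row 0 for both columns
x2.scatter_(0, index2, 0.0)

# 3D case: scatter along dim 1, so x3[i][index3[i][j][k]][k] = 0
x3 = torch.ones(2, 2, 2)
index3 = torch.zeros(2, 1, 2, dtype=torch.long)  # target row 0 in each 2x2 matrix
x3.scatter_(1, index3, 0.0)

print(x2)
print(x3)
```

For a fixed row like this, direct indexing is shorter. Scatter becomes more useful when the target positions differ per column or per batch element, since the index tensor can hold a different row for each entry.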