Replace part of a matrix with a new value

In the simple case, given a [2,2] tensor,
[[1,1],
[1,1]]
we can use the scatter function to replace the entries at indices [0,0] and [0,1] with a new value (such as zero):
[[0,0],
[1,1]]
But given a [2,2,2] tensor
[[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]]
how can I achieve the same operation with the scatter function? The expected result is:
[[[0, 0],
[1, 1]],
[[0, 0],
[1, 1]]]

You don't need scatter for this; you can index the tensor directly:

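import torch

x = torch.ones(2, 2, 2)
idx = torch.tensor([0])  # row(s) to overwrite in every 2x2 slice
x[:, idx] = 0.           # dim 0 untouched (":"), dim 1 selected by idx
print(x)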
> tensor([[[0., 0.],
           [1., 1.]],

          [[0., 0.],
           [1., 1.]]])
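Since the question asks specifically about scatter, here is a minimal sketch of how `scatter_` could express both the 2D and the 3D case, with `index_fill_` shown as another built-in way to get the same result. The variable names `a`, `b`, `c`, `d` are only for illustration. The detail that makes the 3D case work is that the index passed to `scatter_` has the same number of dimensions as the target, so a (2, 1, 2) index of zeros is used rather than `[0]`.

```python
import torch

# 2D case from the question: zero entries [0, 0] and [0, 1].
# Along dim=1, scatter_ writes `value` to a[i][index[i][j]].
a = torch.ones(2, 2)
a.scatter_(1, torch.tensor([[0, 1]]), 0.)

# 3D case: zero row 0 in every 2x2 slice.
# Along dim=1, scatter_ writes to b[i][index[i][j][k]][k],
# so an all-zero index of shape (2, 1, 2) targets b[i, 0, k] for all i, k.
b = torch.ones(2, 2, 2)
b.scatter_(1, torch.zeros(2, 1, 2, dtype=torch.long), 0.)

# Same result without scatter: index_fill_ along dim 1, or plain indexing.
c = torch.ones(2, 2, 2)
c.index_fill_(1, torch.tensor([0]), 0.)

d = torch.ones(2, 2, 2)
d[:, torch.tensor([0])] = 0.

print(a)
print(b)
print(torch.equal(b, c), torch.equal(b, d))  # True True
```

When the same rows are overwritten in every slice, plain indexing or `index_fill_` is usually simpler; `scatter_` is more useful when the target positions differ from slice to slice, because each slice gets its own entries in the index tensor.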