Keeping an in-place modification after masking a tensor

Let’s say, for example, that I have:

import torch
import torch.nn.functional as F

x = torch.randn((10, 20))
mask = torch.bernoulli(torch.full(size=(10,), fill_value=0.50)).type(torch.bool)
i = 0
oh = F.one_hot(torch.tensor(1), num_classes=20).type(torch.float)  # one_hot expects a tensor, not an int

When I do

x[mask][i] = oh

I expect one of the rows of x to be replaced by oh, but that doesn’t happen: the change is not kept. Why? Is there an alternative that makes this work?

I found a way to make it work, but it’s not pretty:

i_star = torch.arange(x.size(0))[mask][i]
x[i_star] = oh
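For reference, a self-contained version of that workaround. It maps the i-th selected position back to a row index of x with torch.arange and then assigns through a single indexing operation on x. The fixed seed and forcing mask[0] to True are assumptions added only so the example is deterministic and the mask is never empty.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # assumption: fixed seed so the run is reproducible

x = torch.randn((10, 20))
mask = torch.bernoulli(torch.full(size=(10,), fill_value=0.50)).type(torch.bool)
mask[0] = True  # assumption: guarantee at least one selected row
i = 0
oh = F.one_hot(torch.tensor(1), num_classes=20).type(torch.float)

# Row indices of x selected by the mask, in order; take the i-th one.
i_star = torch.arange(x.size(0))[mask][i]

# A single __setitem__ call on x itself, so the write lands in x.
x[i_star] = oh
```

Because mask[0] is forced to True and i is 0, i_star is 0 here, so row 0 of x ends up equal to oh.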

I’m sure there is a more Pythonic way.

I think this is because you are indexing an intermediate result by “chaining” the indexing operations.
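A small sketch of what the chained form does: Python evaluates x[mask][i] = oh as x.__getitem__(mask).__setitem__(i, oh), so the assignment goes into whatever x[mask] returned, not into x. The fixed mask selecting rows 2, 5 and 7 is an assumption made so the result is deterministic.

```python
import torch
import torch.nn.functional as F

x = torch.randn((10, 20))
mask = torch.zeros(10, dtype=torch.bool)
mask[[2, 5, 7]] = True  # assumption: a fixed mask instead of a random one
i = 0
oh = F.one_hot(torch.tensor(1), num_classes=20).type(torch.float)

before = x.clone()

# Chained indexing: evaluated as x.__getitem__(mask).__setitem__(i, oh)
x[mask][i] = oh

# ...which behaves like writing into a temporary and discarding it:
tmp = x[mask]  # a new tensor holding rows 2, 5 and 7
tmp[i] = oh    # modifies tmp only

print(torch.equal(x, before))   # True: x was never touched
print(torch.equal(tmp[i], oh))  # True: the temporary did change
```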

I’m not sure I fully understand your use case, but it seems you want to index mask with i first, so I guess this should also work:

x[mask.nonzero()[i]] = oh
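To see why this works, note that for a 1-D mask, mask.nonzero() returns the selected indices as a (k, 1) tensor. So mask.nonzero()[i] is a one-element index tensor, x[...] on it has shape (1, 20), and oh broadcasts into it. This is a single assignment on x, so x is modified. The fixed mask and i = 1 are assumptions for a deterministic illustration.

```python
import torch
import torch.nn.functional as F

x = torch.randn((10, 20))
mask = torch.zeros(10, dtype=torch.bool)
mask[[2, 5, 7]] = True  # assumption: fixed mask so the selected rows are known
i = 1
oh = F.one_hot(torch.tensor(1), num_classes=20).type(torch.float)

# For a 1-D mask, nonzero() returns the selected indices with shape (k, 1).
rows = mask.nonzero()
print(rows.shape)  # torch.Size([3, 1])

# rows[i] is tensor([5]); x[rows[i]] has shape (1, 20) and oh (shape (20,))
# broadcasts into it. One __setitem__ call on x, so x itself changes.
x[rows[i]] = oh
print(torch.equal(x[5], oh))  # True
```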

Just adding a note to @ptrblck’s answer: x[mask] is a form of advanced indexing (indexing with a boolean mask), and according to the PyTorch documentation, advanced indexing returns a copy of the data rather than a view of the underlying tensor.

Quoting from Tensor Views — PyTorch 1.11.0 documentation,

When accessing the contents of a tensor via indexing, PyTorch follows Numpy behaviors that basic indexing returns views, while advanced indexing returns a copy.

Thus x[mask][i] = oh first creates a copy through x[mask] and then writes oh into row i of that copy. The copy is a temporary that is discarded right away, so x is unmodified.
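A minimal sketch contrasting the two cases from the quoted documentation, plus one related point: reading x[mask] gives a copy, but a single assignment with the mask on the left-hand side (x[mask] = ...) calls x.__setitem__ directly and does modify x in place. The small 4×3 tensor and the mask values are arbitrary choices for illustration.

```python
import torch

x = torch.arange(12, dtype=torch.float).reshape(4, 3)
mask = torch.tensor([True, False, True, False])

# Basic indexing (a slice) returns a view that shares storage with x.
view = x[1:3]
view[0] = -1.0   # writes through to x[1]

# Advanced (boolean-mask) indexing returns a copy.
copy = x[mask]
copy[0] = -2.0   # changes the copy only; x[0] is unaffected

# A single assignment with the mask on the left-hand side is
# x.__setitem__(mask, 100.0), so it modifies rows 0 and 2 of x in place.
x[mask] = 100.0

print(x)
```

After this runs, row 1 of x holds -1 (written through the view), rows 0 and 2 hold 100, row 3 is unchanged, and the -2 written into the copy never reaches x.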