Question on masked_scatter_

I have difficulties with masked scatter.
Say I have a mask = [1 0 1 0] (a torch.ByteTensor of size 4x1).
And I have two tensors y and z, both of size (3, 4, 4).
I want the rows of y at indices 0 and 2 (the positions where the mask is set) to be filled with the elements at the same positions in z.
I do this:
y.masked_scatter_(mask, z)
But instead of rows 0 and 2 in every channel of y being equal to those of z, I get some strangely permuted rows. I suspect I need an expand_as somewhere?

Here's the code to reproduce:

import torch

y = torch.rand(3, 4, 4)
z = torch.rand(3, 4, 4)
x = torch.rand(4, 1)
mask = x.ge(0.5)
y.masked_scatter_(mask, z)

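Your mask is being automatically broadcast to the shape of y. With masked_scatter_, treat the mask as if it had the same shape as y. Also, masked_scatter_ doesn't index into z at the masked positions; it copies z's values over sequentially, starting from z's first element, so the values that land in y are not the ones at the corresponding positions in z.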
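To make the sequential copying concrete, here is a small sketch with made-up tensors (the shapes and values are only for illustration, not from the original post). It compares masked_scatter_ against a position-wise replacement using torch.where:

```python
import torch

# Small deterministic tensors so the copy order is easy to see (illustrative values).
y = torch.zeros(2, 3)
z = torch.arange(1., 7.).reshape(2, 3)   # [[1, 2, 3], [4, 5, 6]]
mask = torch.tensor([[False, True, False],
                     [True, False, True]])

# masked_scatter_ walks z from its first element onward and drops each value
# into the next position of y where the mask is set.
scattered = y.clone().masked_scatter_(mask, z)
# positions (0,1), (1,0), (1,2) receive z's first three values: 1, 2, 3

# Position-wise replacement, for comparison: take z where the mask is set, else y.
positionwise = torch.where(mask, z, y)
# positions (0,1), (1,0), (1,2) receive z[0,1], z[1,0], z[1,2]: 2, 4, 6
```

The two results differ, which is the "permuted" effect described in the question: the masked positions are the right ones, but the values come from the front of z rather than from the matching positions.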

A better approach is to build a mask that is set for the rows you care about, turn it into indices, and then do a direct assignment:

import torch

y = torch.rand(3, 2, 2)
z = torch.rand(3, 2, 2)
x = torch.rand(3)
mask = x.ge(0.5)

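# indices along the first dimension where the mask is True
indices = mask.nonzero().squeeze(1)

# direct assignment: copy the selected slices of z into the same positions of y
y[indices] = z[indices]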
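For the shapes in the original question, (3, 4, 4) with the mask selecting rows along dim 1, the same idea might look like the sketch below. It assumes the mask is a 1-D boolean tensor of length 4 (for a 4x1 mask, mask.squeeze(1) should give that shape); the names y_before and y_alt are only there for the comparison.

```python
import torch

y = torch.rand(3, 4, 4)
z = torch.rand(3, 4, 4)
y_before = y.clone()                              # kept only to compare afterwards
mask = torch.tensor([True, False, True, False])   # one flag per row (dim 1)

# Boolean indexing along dim 1: copy rows 0 and 2 of every channel from z.
y[:, mask] = z[:, mask]

# Equivalent out-of-place form: reshape the mask to (1, 4, 1) so it broadcasts
# over channels and columns, then pick z or y position by position.
y_alt = torch.where(mask.view(1, -1, 1), z, y_before)
```

Both forms leave rows 1 and 3 of y untouched. The torch.where version may be handier when you want a new tensor rather than an in-place update.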