I'm having trouble with `masked_scatter_`.
Say I have mask = [1, 0, 1, 0] (a torch.ByteTensor of size 4x1).
And I have two tensors y and z, both of size (3, 4, 4).
I want the rows of y that correspond to indices 0 and 2 in the mask to be filled, in every channel, with the elements at the same positions in z.
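This sketch shows what a mask like that selects. It assumes a recent PyTorch, where masks are torch.bool rather than ByteTensor, and the literal mask below is a hypothetical stand-in for the one described above. Broadcasting aligns trailing dimensions, so a (4, 1) mask is treated as (1, 4, 1) and expanded to (3, 4, 4). That marks whole rows 0 and 2 in every channel.

```python
import torch

# Hypothetical stand-in for mask = [1, 0, 1, 0] of size 4x1 (bool instead of ByteTensor)
mask = torch.tensor([[1], [0], [1], [0]], dtype=torch.bool)
y = torch.rand(3, 4, 4)

# Broadcasting aligns trailing dims: (4, 1) -> (1, 4, 1) -> (3, 4, 4)
full_mask = mask.expand_as(y)

print(full_mask.shape)       # torch.Size([3, 4, 4])
print(int(full_mask.sum()))  # 24 = 3 channels * 2 rows * 4 columns
print(full_mask[0])          # rows 0 and 2 all True, rows 1 and 3 all False
```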
I do this:

```python
y.masked_scatter_(mask, z)
```
But instead of rows 0 and 2 of y becoming equal to those of z in every channel, I get strange, seemingly permuted rows. I suspect I need an `expand_as` somewhere?
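A likely explanation, sketched here with deterministic values so the order is visible: `masked_scatter_` does not read from the matching positions of the source. It walks the True positions of the broadcast mask in row-major order and fills them with consecutive elements of the source, starting from its first element. With a (4, 1) mask over (3, 4, 4), the 24 selected slots receive the first 24 values of z, which looks like shuffled rows. The zero and arange tensors are illustrative assumptions, not the original data.

```python
import torch

# Illustrative values: y starts at zero, z holds 0..47 so each element's
# original flat position in z is visible after the scatter.
y = torch.zeros(3, 4, 4)
z = torch.arange(48, dtype=torch.float32).reshape(3, 4, 4)
mask = torch.tensor([[1], [0], [1], [0]], dtype=torch.bool)  # rows 0 and 2

y.masked_scatter_(mask, z)

# Source elements are consumed in order, not taken from matching positions:
print(y[0, 2])  # values 4, 5, 6, 7 (the 5th to 8th elements of z)
print(z[0, 2])  # values 8, 9, 10, 11
print(y[1, 0])  # values 8, 9, 10, 11 (z[0, 2]'s data ends up in channel 1)
```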
Here's code to reproduce it:

```python
import torch

y = torch.rand(3, 4, 4)
z = torch.rand(3, 4, 4)
x = torch.rand(4, 1)
mask = x.ge(0.5)  # (4, 1) mask, broadcast over channels and columns
y.masked_scatter_(mask, z)
```
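Here is a minimal sketch of three ways to copy values position by position. It assumes a recent PyTorch and a bool mask, and the fixed seed and literal mask are hypothetical choices made for reproducibility. `torch.where` broadcasts the (4, 1) mask directly, so it needs no `expand_as`. Boolean indexing needs the expanded mask on both sides of the assignment. `masked_scatter_` also works if its source is first filtered with the same mask, because the source order then matches the order of the target positions.

```python
import torch

torch.manual_seed(0)  # hypothetical seed, for reproducibility only
y = torch.rand(3, 4, 4)
z = torch.rand(3, 4, 4)
mask = torch.tensor([[1], [0], [1], [0]], dtype=torch.bool)  # rows 0 and 2

# Option 1: out-of-place select. The mask broadcasts to (3, 4, 4);
# each element comes from z where True and from y elsewhere.
result = torch.where(mask, z, y)

# Option 2: in-place boolean indexing with an explicitly expanded mask,
# so values are read from z at the same positions they are written to in y.
full_mask = mask.expand_as(y)
y_inplace = y.clone()
y_inplace[full_mask] = z[full_mask]

# Option 3: keep masked_scatter_, but pre-select the source with the same
# mask so its sequential order lines up with the target positions.
y_scatter = y.clone()
y_scatter.masked_scatter_(mask, z[full_mask])
```

`torch.where` is probably the simplest choice when a new tensor is acceptable. The other two options write into a copy of y in place, and the same lines work directly on y if mutating it is intended.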