Torch.where weird behavior

I have the following expression:

import torch
torch.manual_seed(0)

x = torch.rand(3,7,7,5)
mask1 = x[:,:,:,4] > .5 # <- [3,7,7]
x1, x2 = torch.rand_like(x), torch.rand_like(x)
mask2 = torch.rand(len(x[mask1])) > .5 # <- [74]

print(x[mask1].size(), mask2.size(), x1[mask1].size(), x2[mask1].size())
x[mask1] = torch.where(mask2, x1[mask1], x2[mask1])

The output of the print statement is:

torch.Size([74, 5]) torch.Size([74]) torch.Size([74, 5]) torch.Size([74, 5])

It seems that my tensors are compatible with each other. However, the torch.where line raises the following RuntimeError: The size of tensor a (74) must match the size of tensor b (5) at non-singleton dimension 1.

Does anyone know what the problem is here?

When I do the torch.where logic manually, like this:

x[mask1][mask2] = x1[mask1][mask2]
x[mask1][~mask2] = x2[mask1][~mask2]

It seems to work fine, and to me that is the same logic as torch.where. Can someone explain?

Edit: Fixed typo

Hi i4!

First, you have a typo in your torch.where() line: you want x2[mask1] (rather than
x2[mask2]).

Your core problem is that the arguments you pass to where() don’t
have the same shapes, so where() tries to broadcast them, but the
shapes aren’t broadcastable.
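To see why broadcasting fails, here is a minimal sketch that checks the shapes directly with torch.broadcast_shapes (available in PyTorch 1.8 and later). The shapes are hard-coded from the question's print output rather than taken from live tensors. Broadcasting compares dimensions starting from the right, which is why 74 ends up lined up against 5.

```python
import torch

# Shapes from the question: mask2 is [74], x1[mask1] and x2[mask1] are [74, 5]
mask_shape = (74,)
value_shape = (74, 5)

# Broadcasting aligns shapes from the right, so 74 is compared against 5
try:
    torch.broadcast_shapes(mask_shape, value_shape)
    compatible = True
except RuntimeError as err:
    compatible = False
    print("not broadcastable:", err)

# With a trailing singleton dimension, [74, 1] broadcasts against [74, 5]
fixed_shape = torch.broadcast_shapes((74, 1), value_shape)
print(fixed_shape)  # torch.Size([74, 5])
```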

Specifically, the last dimension of mask2 is 74, while the last dimension
of, for example, x1[mask1] is 5. Adding a trailing singleton dimension
to mask2 (and fixing the typo) should do what you want:

x[mask1] = torch.where(mask2.unsqueeze(-1), x1[mask1], x2[mask1])
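For completeness, a runnable version of that fix that also checks the result, reusing the question's setup. The names `selected` and `rows` are just illustrative, and the row count comes from mask1.sum() instead of len(x[mask1]) (the two are equal).

```python
import torch

torch.manual_seed(0)

x = torch.rand(3, 7, 7, 5)
mask1 = x[:, :, :, 4] > .5                       # [3, 7, 7]
x1, x2 = torch.rand_like(x), torch.rand_like(x)
mask2 = torch.rand(int(mask1.sum())) > .5        # one entry per selected row, [N]

# unsqueeze(-1) turns mask2 from [N] into [N, 1], which broadcasts over the last dim of size 5
selected = torch.where(mask2.unsqueeze(-1), x1[mask1], x2[mask1])  # [N, 5]
x[mask1] = selected

# Rows picked by mask2 now come from x1, the rest from x2
rows = x[mask1]
assert torch.equal(rows[mask2], x1[mask1][mask2])
assert torch.equal(rows[~mask2], x2[mask1][~mask2])
```

For a 1-D mask like this, mask2[:, None] is an equivalent spelling of mask2.unsqueeze(-1).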

As an aside, your manual version (x[mask1][mask2] = x1[mask1][mask2]) won’t do what you
think it will: it doesn’t actually modify x.

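x[mask1] = x1[mask1] would modify x, but, perhaps counterintuitively,
x[mask1][mask2] = ... would not. Python evaluates it as a read of x[mask1]
followed by an assignment into that result, and reading x[mask1] with a
boolean mask returns a new tensor (a copy, not a view of x). The assignment
therefore lands in that temporary copy, which is then discarded, so you
can’t use x[mask1][mask2] to assign back into x.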
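A small sketch of that copy behaviour, reusing the question's setup. The full-size mask `combined` is a hypothetical workaround added for illustration, not part of the original reply: it folds mask2 back into a mask over x's first three dimensions so that a single indexing step can assign into x. The torch.where() version is usually the simpler choice.

```python
import torch

torch.manual_seed(0)
x = torch.rand(3, 7, 7, 5)
mask1 = x[:, :, :, 4] > .5
x1 = torch.rand_like(x)
mask2 = torch.rand(int(mask1.sum())) > .5

before = x.clone()

# Boolean-mask indexing on read returns a new tensor (a copy),
# so this assignment writes into that temporary and x is unchanged.
x[mask1][mask2] = x1[mask1][mask2]
unchanged = torch.equal(x, before)

# Hypothetical workaround: a full-size mask that is True only where both masks select a row.
combined = mask1.clone()
combined[mask1] = mask2          # same row order as x[mask1]
x[combined] = x1[combined]       # single indexing step, so x is modified in place

print(unchanged)                                # True
print(torch.equal(x[combined], x1[combined]))  # True
```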

Best.

K. Frank


Hello,

Thank you so much for your reply! 🙂