Applying mask using torch.where function

Hi,

I am trying to apply a condition on tensors. What I am trying to achieve is: when the mask is true, use the value from X; otherwise, use the value from Y. The logic works fine using np.where, and I am trying to achieve the same using torch.where. Is there anything I am missing in the following?

weights_x and weights_y both have size:

torch.Size([32, 1, 5, 5])
torch.Size([32, 1, 5, 5])

mask = th.tensor([False, True, False, True, True, False, False, True, True, True, False, True, False, True, True, True, False, True, True, False, False, True, False, False, False, False, True, False, False, False, True, True])

NumPy version:
np.where(mask, weights_x, weights_y)

PyTorch version:
th.where(mask == True, weights_x, weights_y)

Error:

if isinstance(args_, tuple):
--> 414             response = command_method(*args_, **kwargs_)
    415         else:
    416             response = command_method(args_, **kwargs_)

RuntimeError: The size of tensor a (32) must match the size of tensor b (5) at non-singleton dimension 3

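Both NumPy and PyTorch will fail with a shape mismatch error here. Broadcasting aligns shapes starting from the trailing dimension, so the mask of shape [32] is matched against the last dimension of the weights (size 5) instead of the first one. You would need to unsqueeze the mask tensor to shape [32, 1, 1, 1]: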

import numpy as np
import torch

weights_x = torch.zeros([32, 1, 5, 5])
weights_y = torch.ones([32, 1, 5, 5])
mask = torch.tensor([False, True, False, True, True, False, False, True, True, True, False, True, False, True, True, True, False, True, True, False, False, True, False, False, False, False, True, False, False, False, True, True])

np.where(mask, weights_x, weights_y) # fails
> ValueError: operands could not be broadcast together with shapes (32,) (32,1,5,5) (32,1,5,5) 

torch.where(mask == True, weights_x, weights_y) # fails
> RuntimeError: The size of tensor a (32) must match the size of tensor b (5) at non-singleton dimension 3

# works: the mask now has shape [32, 1, 1, 1] and broadcasts against [32, 1, 5, 5]
arr = np.where(mask[:, None, None, None], weights_x, weights_y)
out = torch.where(mask[:, None, None, None], weights_x, weights_y)

print((arr == out.numpy()).all())
> True
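The indexing mask[:, None, None, None] hard-codes the number of trailing dimensions. As a minimal sketch (the helper name expand_mask is hypothetical, not a PyTorch API), the same reshape can be derived from the target tensor's number of dimensions, which also makes it easy to place the mask on a different dimension:

```python
import torch


def expand_mask(mask: torch.Tensor, ndim: int, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: reshape a 1-D mask so it broadcasts along `dim`
    of an `ndim`-dimensional tensor (all other dimensions get size 1)."""
    shape = [1] * ndim
    shape[dim] = mask.numel()
    return mask.view(shape)


torch.manual_seed(0)
weights_x = torch.zeros(32, 1, 5, 5)
weights_y = torch.ones(32, 1, 5, 5)
mask = torch.rand(32) > 0.5  # random boolean mask, one entry per sample (dim 0)

# expand_mask(mask, 4, dim=0) has shape [32, 1, 1, 1], same as mask[:, None, None, None]
out = torch.where(expand_mask(mask, weights_x.dim(), dim=0), weights_x, weights_y)
```

Using view keeps this a cheap reshape without copying; mask.reshape(shape) would work equally well for a contiguous 1-D mask.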

@ptrblck Thanks for the reply, it worked. How can I apply this where condition along the y-axis? I have a mask of shape torch.Size([2048]) and a tensor of shape torch.Size([62, 2048]), and I want to apply the mask along the y-axis. Could you please comment? Thanks.
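A sketch of how the broadcasting rule above applies to these shapes, assuming "along the y-axis" means the 2048 mask entries select columns (dim 1); the tensors x and y below are hypothetical placeholders. Because broadcasting aligns trailing dimensions, a [2048] mask already lines up with the last dimension of a [62, 2048] tensor, so no unsqueeze is needed. A mask over rows (dim 0) would instead need 62 entries and an explicit trailing dimension:

```python
import torch

torch.manual_seed(0)
x = torch.randn(62, 2048)  # hypothetical data with the shapes from the question
y = torch.zeros(62, 2048)

# Mask over columns (dim 1): trailing dims match (2048 vs 2048), broadcasts as [1, 2048]
col_mask = torch.rand(2048) > 0.5
out_cols = torch.where(col_mask, x, y)

# Mask over rows (dim 0): needs 62 entries and shape [62, 1] to broadcast
row_mask = torch.rand(62) > 0.5
out_rows = torch.where(row_mask[:, None], x, y)
```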