Hi,
I am trying to apply a condition on tensors. What I want to achieve is: where mask is True, use the value from X; otherwise, use the value from Y. The logic works fine with the np.where function, and I am trying to do the same with torch.where. Am I missing anything in the following?
weights_x and weights_y both have size torch.Size([32, 1, 5, 5]), and mask is a 1-D boolean tensor with 32 elements:
mask= th.tensor([False, True, False, True, True, False, False, True, True, True, False, True, False, True, True, True, False, True, True, False, False, True, False, False, False, False, True, False, False, False, True, True])
NumPy version:
np.where(mask, weights_x, weights_y)
PyTorch version:
th.where(mask == True, weights_x, weights_y)
Error:
if isinstance(args_, tuple):
--> 414 response = command_method(*args_, **kwargs_)
415 else:
416 response = command_method(args_, **kwargs_)
RuntimeError: The size of tensor a (32) must match the size of tensor b (5) at non-singleton dimension 3
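For anyone trying to reproduce this, here is a minimal self-contained sketch using random stand-in data with the same shapes. The random values are an assumption (not the original weights), and torch is imported under its full name rather than as th. One likely explanation for the error is that broadcasting aligns shapes starting from the trailing dimension, so the 32-element mask is matched against the last dimension of the weights (size 5). Reshaping the mask to [32, 1, 1, 1] is one way to make it select whole samples.

```python
import torch

torch.manual_seed(0)

# Stand-in data with the shapes from the post (random values, illustration only).
weights_x = torch.randn(32, 1, 5, 5)
weights_y = torch.randn(32, 1, 5, 5)
mask = torch.rand(32) > 0.5  # 1-D boolean mask, one entry per sample

# Broadcasting aligns trailing dimensions, so a shape-[32] mask is matched
# against the last dimension (size 5) of the weights, which raises an error.
try:
    torch.where(mask, weights_x, weights_y)
except RuntimeError as err:
    print("broadcast error:", err)

# Give the mask trailing singleton dimensions: [32] -> [32, 1, 1, 1].
# It now broadcasts over the remaining dims and picks whole samples.
mask_4d = mask.view(-1, 1, 1, 1)
out = torch.where(mask_4d, weights_x, weights_y)
print(out.shape)  # torch.Size([32, 1, 5, 5])
```

mask[:, None, None, None] is an equivalent way to add the singleton dimensions.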