Why is there a constraint on tensor sizes in torch.where?

I basically want to compute torch.where over the outer dimension. Here is a snippet from IPython:

In [599]: a.shape
Out[599]: torch.Size([32, 4, 13])

In [600]: b.shape
Out[600]: torch.Size([32, 4, 13])

In [601]: c.shape
Out[601]: torch.Size([32])

In [602]: torch.where(c, a, b)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-602-4767f96aa34a> in <module>()
----> 1 torch.where(c, a, b)

RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 2

Right now I am doing this with the following command:

out = torch.stack([a[i] if c[i] else b[i] for i in range(len(c))])

Please let me know if I have missed a more suitable torch API or a more Pythonic way to perform the desired task.

The tensors should be broadcastable, so this would work:

a = torch.randn(32, 4, 13)
b = torch.randn(32, 4, 13)

c = torch.randint(0, 2, (32, 1, 1), dtype=torch.bool)  # condition must be boolean in current PyTorch

d = torch.where(c, a, b)

In your case, just get a new view on c:

c = c.view(-1, 1, 1)
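For completeness, here is a minimal sketch (assuming `c` is a boolean mask of shape `(32,)`, as in the question) showing that the broadcasted `torch.where` produces the same result as the stack-based loop:

```python
import torch

a = torch.randn(32, 4, 13)
b = torch.randn(32, 4, 13)
c = torch.randint(0, 2, (32,), dtype=torch.bool)  # per-sample condition

# Reshape c to (32, 1, 1) so it broadcasts over the trailing dims of a and b
d = torch.where(c.view(-1, 1, 1), a, b)

# Equivalent loop-based version from the question
out = torch.stack([a[i] if c[i] else b[i] for i in range(len(c))])

assert torch.equal(d, out)
```

The broadcasting avoids the Python-level loop entirely, so it also stays on the GPU if the tensors live there.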