I want to use torch.where with a condition that selects along the outer (first) dimension only. Here is the IPython session:
In [599]: a.shape
Out[599]: torch.Size([32, 4, 13])
In [600]: b.shape
Out[600]: torch.Size([32, 4, 13])
In [601]: c.shape
Out[601]: torch.Size([32])
In [602]: torch.where(c, a, b)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-602-4767f96aa34a> in <module>()
----> 1 torch.where(c, a, b)
RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 2
Right now I am doing this with the following command:
out = torch.stack([a[i] if c[i] else b[i] for i in range(len(c))])
Please let me know if there is a torch API I have overlooked, or a more Pythonic way to perform this task.
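For reference, one possible direction is to rely on broadcasting. The error above arises because broadcasting aligns shapes from the trailing dimension, so c of shape [32] is matched against the last dimension of size 13. A minimal sketch, assuming c is a boolean tensor and using random stand-in data in place of my real tensors, reshapes c to [32, 1, 1] so that it lines up with the first dimension instead:

```python
import torch

# Stand-in tensors with the shapes from the session above (random data is an assumption).
a = torch.randn(32, 4, 13)
b = torch.randn(32, 4, 13)
c = torch.rand(32) > 0.5  # assumed to be a boolean condition tensor

# Add two trailing singleton dimensions: (32,) -> (32, 1, 1).
# The condition then broadcasts across the last two dimensions of a and b.
out = torch.where(c[:, None, None], a, b)

# Compare against the list-comprehension version for a sanity check.
expected = torch.stack([a[i] if c[i] else b[i] for i in range(len(c))])
assert torch.equal(out, expected)
```

Writing c.view(-1, 1, 1) or c.unsqueeze(-1).unsqueeze(-1) in place of the indexing expression should give the same reshaped condition.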