I have a problem with `torch.where(...)`. My code relies on many small matrix-vector products, which I would like to vectorise by stacking the matrices along a third (batch) dimension. As an example, consider the following snippet:
```python
import torch

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]).view(3, 2, 2)
# tensor([[[ 1,  2],
#          [ 3,  4]],
#
#         [[ 5,  6],
#          [ 7,  8]],
#
#         [[ 9, 10],
#          [11, 12]]])

x = torch.tensor([1, 2, 3, 4, 5, 6]).view(3, 2, 1)
# tensor([[[1],
#          [2]],
#
#         [[3],
#          [4]],
#
#         [[5],
#          [6]]])

y = torch.matmul(A, x)
# tensor([[[  5],
#          [ 11]],
#
#         [[ 39],
#          [ 53]],
#
#         [[105],
#          [127]]])
```
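To check that `torch.matmul` really performs one small matrix-vector product per slice along the leading dimension, a minimal sketch can compare it with an explicit Python loop. `torch.arange` is used here only as shorthand for the same input values, and `y_loop` is just an illustrative name for the reference result.

```python
import torch

A = torch.arange(1, 13).view(3, 2, 2)   # same values as in the question
x = torch.arange(1, 7).view(3, 2, 1)

# Batched product: matmul treats the leading dimension as a batch dimension.
y = torch.matmul(A, x)                  # shape (3, 2, 1)

# Reference: one 2x2 @ 2x1 product per batch entry, stacked back together.
y_loop = torch.stack([A[i] @ x[i] for i in range(A.shape[0])])

print(y.shape)                          # torch.Size([3, 2, 1])
print(torch.equal(y, y_loop))           # True
```

For strictly 3-D inputs like these, `torch.bmm(A, x)` should give the same result.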
This works fine, but I need to handle some corner cases. Let

```python
idx = torch.tensor([1, 2, 3])
```

What I want to achieve is the following (keep `A[i]` where `idx[i] > 1` and use zeros otherwise):

```python
torch.where(idx > 1, A, 0.0)
```
which leads to the following error:

```
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 2
```
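The error comes from broadcasting: shapes are compared starting from the trailing dimension, so the condition `idx > 1` of shape `(3,)` is lined up with the last axis of `A` (size 2), not with the batch axis. A minimal sketch, assuming the goal is one flag per matrix, reshapes the condition to `(3, 1, 1)` so that it broadcasts over each 2x2 matrix. `cond` and `A_masked` are illustrative names; `torch.zeros_like(A)` is used instead of `0.0` on the assumption that older PyTorch versions may reject mixing an integer tensor with a float scalar in `torch.where`.

```python
import torch

A = torch.arange(1, 13).view(3, 2, 2)
idx = torch.tensor([1, 2, 3])

cond = idx > 1                  # shape (3,): lined up with A's LAST axis -> error
cond = cond.view(-1, 1, 1)      # shape (3, 1, 1): one flag per 2x2 matrix

# zeros_like keeps A's integer dtype (assumption: avoids dtype issues with 0.0
# on older PyTorch versions).
A_masked = torch.where(cond, A, torch.zeros_like(A))
print(A_masked[0])              # all zeros, because idx[0] = 1 is not > 1
```

`idx[:, None, None] > 1` is an equivalent way to add the two singleton axes. Since a zero matrix times any `x` is a zero vector, masking `A` before `torch.matmul(A, x)` should be equivalent to masking `y` afterwards with the same reshaped condition.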
Making the matrix `A` of size `[2,2,3]` does work (i.e. no error is raised), but of course it does not give meaningful results. Any help is appreciated.
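The `[2,2,3]` version runs because, with three entries on the last axis, the `(3,)` condition now has an axis to line up with. Whether the result means anything depends on how that tensor was built, which the question does not say. The sketch below assumes two possibilities: `view(2, 2, 3)`, which only reinterprets the flat memory and scrambles the matrices, and `permute(1, 2, 0)`, which actually moves the batch axis to the end. `A_view`, `A_perm` and `masked` are illustrative names.

```python
import torch

A = torch.arange(1, 13).view(3, 2, 2)        # batch axis first, as in the question
idx = torch.tensor([1, 2, 3])

# Assumption 1: the [2,2,3] tensor was made with view(). view() only reinterprets
# the flat storage, so slice k along the last axis is NOT the k-th matrix of A.
A_view = A.view(2, 2, 3)
print(torch.equal(A_view[:, :, 0], A[0]))    # False

# Assumption 2: the [2,2,3] tensor was made with permute(), which really moves
# the batch axis to the end, so slice k IS the k-th matrix of A.
A_perm = A.permute(1, 2, 0)
print(torch.equal(A_perm[:, :, 0], A[0]))    # True

# Broadcasting lines up the (3,) condition with the last axis of A_perm,
# i.e. with the batch axis, so whole matrices are kept or zeroed.
masked = torch.where(idx > 1, A_perm, torch.zeros_like(A_perm))
```

`torch.matmul` treats leading dimensions as batch dimensions, so a batch-last layout like `masked` would have to be moved back with `masked.permute(2, 0, 1)` before the batched product.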