Specify dimension in torch.where

I have a problem with torch.where(...). My code relies on many small matrix-vector products, which I would like to vectorise over the third dimension. As an example, consider the following code snippet:

A = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12]).view(3,2,2)
tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]],

        [[ 9, 10],
         [11, 12]]])

x = torch.tensor([1,2,3,4,5,6]).view(3,2,1)
tensor([[[1],
         [2]],

        [[3],
         [4]],

        [[5],
         [6]]])

y = torch.matmul( A, x )
tensor([[[  5],
         [ 11]],

        [[ 39],
         [ 53]],

        [[105],
         [127]]])
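For reference, torch.matmul treats every dimension except the last two as a batch dimension, so in the snippet above the product is actually batched over the first dimension (size 3), not the third. A minimal sketch that checks this against an explicit loop over dim 0 (same A and x, built with torch.arange; y_loop is just an illustrative name):

```python
import torch

# Batched matrix-vector product: the batch runs over the FIRST (leading) dimension.
A = torch.arange(1, 13).view(3, 2, 2)   # 3 matrices of shape 2x2
x = torch.arange(1, 7).view(3, 2, 1)    # 3 column vectors of shape 2x1

y = torch.matmul(A, x)                  # shape (3, 2, 1)

# Same computation as an explicit loop over dim 0: one (2x2) @ (2x1) product per slice.
y_loop = torch.stack([A[i] @ x[i] for i in range(A.shape[0])])

print(y.shape)                 # torch.Size([3, 2, 1])
print(torch.equal(y, y_loop))  # True
```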

This works just fine, but I need to catch some corner cases. Let

idx = torch.tensor([1,2,3])

What I want to achieve is the following:

torch.where(idx > 1, A, 0.0)

which leads to the following error message:

RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 2

Making A of size [2,2,3] does work (i.e. raises no errors), but of course does not give meaningful results. Any help is appreciated.
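The error comes from broadcasting: shapes are aligned starting from the last dimension, so a mask of shape [3] is compared with A's trailing dimension of size 2 rather than with dim 0. That is also why a [2,2,3] layout runs without error: the mask then lines up with the last dimension. A minimal sketch, assuming the goal is to keep or zero whole 2x2 matrices along dim 0, gives the mask explicit singleton dimensions so it broadcasts over each matrix. It uses torch.zeros_like(A) instead of the scalar 0.0 so the result keeps A's integer dtype; mask, failed and B are illustrative names:

```python
import torch

A = torch.arange(1, 13).view(3, 2, 2)  # batch of 3 matrices, batch on dim 0
idx = torch.tensor([1, 2, 3])
mask = idx > 1                         # shape (3,): [False, True, True]

# Broadcasting lines shapes up from the LAST dimension, so (3,) is compared
# with A's trailing size 2 and torch.where raises, as in the question.
try:
    torch.where(mask, A, torch.zeros_like(A))
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# Reshape the mask from (3,) to (3, 1, 1) so it lines up with dim 0
# and is broadcast across each 2x2 matrix (mask[:, None, None] is equivalent).
B = torch.where(mask.view(-1, 1, 1), A, torch.zeros_like(A))
print(B)  # first matrix zeroed (mask False), the other two unchanged
```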

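I’m not sure what the expected results are, but given that your idx tensor has 3 values, I would guess you want to index A in dim0. If so, indexing should work. Since torch.where(idx > 1, A, 0.0) keeps A where the condition holds and writes 0 elsewhere, the in-place equivalent zeroes the slices where idx <= 1:

A[idx <= 1] = 0.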
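Boolean indexing with a 1-D mask selects whole matrices along dim 0, and assigning through it modifies A in place. Note that torch.where(idx > 1, A, 0.0) keeps A where the condition is True, so an in-place version has to zero the complement, ~(idx > 1). A minimal sketch (A_masked and mask are illustrative names) that clones A first to keep the original, applies the mask, and then feeds the result into the batched matmul:

```python
import torch

A = torch.arange(1, 13).view(3, 2, 2)
x = torch.arange(1, 7).view(3, 2, 1)
idx = torch.tensor([1, 2, 3])
mask = idx > 1                    # [False, True, True]

# A 1-D boolean mask on a 3-D tensor selects whole slices along dim 0.
print(A[mask].shape)              # torch.Size([2, 2, 2])

# Assignment through the mask is in place, so work on a copy to keep A intact.
A_masked = A.clone()
A_masked[~mask] = 0               # zero the matrices where idx > 1 is False,
                                  # matching what torch.where(idx > 1, A, 0) produces

# Zeroed matrices simply give zero products in the batched matmul.
y = torch.matmul(A_masked, x)
print(y.squeeze(-1))
```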

Thanks. Could you also let me know the C++ equivalent of this?