I have a problem with `torch.where(...)`. My code relies on many small matrix-vector products, which I would like to vectorise over a third (batch) dimension. As an example, consider the following code snippet:

```
import torch

A = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12]).view(3,2,2)
# tensor([[[ 1,  2],
#          [ 3,  4]],
#
#         [[ 5,  6],
#          [ 7,  8]],
#
#         [[ 9, 10],
#          [11, 12]]])
x = torch.tensor([1,2,3,4,5,6]).view(3,2,1)
# tensor([[[1],
#          [2]],
#
#         [[3],
#          [4]],
#
#         [[5],
#          [6]]])
y = torch.matmul(A, x)
# tensor([[[  5],
#          [ 11]],
#
#         [[ 39],
#          [ 53]],
#
#         [[105],
#          [127]]])
```
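For reference, `torch.matmul` treats the leading dimension of a 3-D tensor as the batch dimension (here the stack of three 2x2 matrices), so the call above computes `A[i] @ x[i]` for each `i`. The sketch below checks that against an explicit loop; `y_loop` is a hypothetical name used only for the comparison, and `torch.arange` builds the same integer tensors as the original `torch.tensor([...])` calls.

```python
import torch

# Same data as above: three 2x2 matrices and three 2x1 vectors (int64)
A = torch.arange(1, 13).view(3, 2, 2)
x = torch.arange(1, 7).view(3, 2, 1)

# matmul broadcasts over the leading (batch) dimension
y = torch.matmul(A, x)

# Explicit per-matrix loop, only to show what matmul computes
y_loop = torch.stack([A[i] @ x[i] for i in range(A.shape[0])])

print(torch.equal(y, y_loop))  # True
print(y.squeeze(-1))           # one row of results per batch entry
```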

This works fine, but I need to handle some corner cases. Let

```
idx = torch.tensor([1,2,3])
```

What I want to achieve is the following (keep `A[i]` where `idx[i] > 1`, otherwise replace it with zeros):

```
torch.where(idx > 1, A, 0.0)
```

which leads to the following error message:

```
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 2
```
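The error comes from broadcasting: PyTorch aligns shapes starting from the trailing dimension, so the size-3 mask `idx > 1` is matched against the last dimension of `A` (size 2). A minimal sketch, assuming the goal is to zero out each 2x2 matrix `A[i]` for which `idx[i] > 1` is false, reshapes the mask to `(3, 1, 1)` so it lines up with the batch dimension instead. `mask`, `mask3`, `raised` and `B` are illustrative names, not from the original.

```python
import torch

A = torch.arange(1, 13).view(3, 2, 2)
idx = torch.tensor([1, 2, 3])
mask = idx > 1  # shape (3,), one boolean per matrix

# Broadcasting aligns shapes from the LAST dimension, so (3,) is compared
# with A's trailing size 2, which is what triggers the RuntimeError.
raised = False
try:
    torch.where(mask, A, torch.zeros_like(A))
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Give the mask two trailing singleton dims so its size-3 axis lines up
# with A's first (batch) axis: (3, 1, 1) broadcasts against (3, 2, 2).
mask3 = mask.view(-1, 1, 1)  # equivalently: mask[:, None, None]

# zeros_like keeps A's integer dtype, avoiding any int/float promotion issue
B = torch.where(mask3, A, torch.zeros_like(A))
print(B)  # A[0] is zeroed because idx[0] = 1 is not > 1
```

Whether a bare `0.0` scalar is accepted as the third argument may depend on the PyTorch version, which is why the sketch passes `torch.zeros_like(A)` instead.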

Making matrix `A` of size [2,2,3] *does* work (i.e. no errors), but of course it does not give meaningful results. Any help is appreciated.
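A minimal sketch of why the [2,2,3] layout gives meaningless results, assuming it was built with `view`: `view` only reinterprets the flat data, so `A.view(2, 2, 3)[:, :, k]` is not the matrix `A[k]`. `permute` instead moves the batch axis to the end, after which the `(3,)` mask broadcasts against the intended axis. `A_view`, `A_perm`, `B_perm` and `B` are illustrative names.

```python
import torch

A = torch.arange(1, 13).view(3, 2, 2)
idx = torch.tensor([1, 2, 3])

# view() only reinterprets the flat data; slice [:, :, 0] is NOT A[0]
A_view = A.view(2, 2, 3)
print(torch.equal(A_view[:, :, 0], A[0]))  # False

# permute() moves the batch axis to the end; slice [:, :, k] equals A[k]
A_perm = A.permute(1, 2, 0)                # shape (2, 2, 3)
print(torch.equal(A_perm[:, :, 0], A[0]))  # True

# The (3,) mask now lines up with the last (batch) axis as intended
B_perm = torch.where(idx > 1, A_perm, torch.zeros_like(A_perm))

# Move the batch axis back to the front to match the original layout
B = B_perm.permute(2, 0, 1)
print(B[0])  # all zeros, since idx[0] = 1 is not > 1
```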