Consider the following tensors:
import torch

n = 2
a = torch.rand(3, 7, 7, n, 5)
b = torch.rand(3, 7, 7, n)
Is it possible to select the elements of a where the corresponding element of b is the largest, without reshaping a and b beforehand? The new tensor must have the shape (3,7,7,5). That can be easily achieved with torch.where if n = 2, but how can it be done for n != 2? I imagine something like torch.index_select, though it only works for 1-dimensional indices.
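One possible approach, offered as a minimal sketch rather than a definitive answer, is to combine argmax with torch.gather. It assumes "largest" means the maximum of b along the dimension of size n (dim 3), and that ties may resolve to any one of the maximal positions. The names idx and selected, and the choice n = 4, are illustrative only. Neither a nor b is reshaped; only the index tensor is expanded to line up with a.

```python
import torch

n = 4  # any n >= 1 should work; 4 is chosen only for illustration
a = torch.rand(3, 7, 7, n, 5)
b = torch.rand(3, 7, 7, n)

# Position of the largest entry of b along dim 3, kept as a size-1 dim.
idx = b.argmax(dim=3, keepdim=True)                          # (3, 7, 7, 1)

# gather needs an index with as many dims as a, so add a trailing dim
# and expand it (a view, no copy) across the last dimension of a.
idx = idx.unsqueeze(-1).expand(-1, -1, -1, -1, a.size(-1))   # (3, 7, 7, 1, 5)

# Pick the chosen slice along dim 3, then drop the now size-1 dim.
selected = torch.gather(a, 3, idx).squeeze(3)                # (3, 7, 7, 5)
```

For any position (i, j, k), selected[i, j, k] should then equal a[i, j, k, b[i, j, k].argmax()].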