Indexing in multiple dimensions according to another tensor

Consider the following tensors:

n = 2
a = torch.rand(3,7,7,n,5)
b = torch.rand(3,7,7,n)

Is it possible to select, for each position in the first three dimensions, the slice of a along dimension 3 whose corresponding element of b is the largest, without reshaping a and b beforehand? The resulting tensor must have the shape (3,7,7,5). This is easy with torch.where if n = 2, but how can it be done for n != 2? I imagine something like torch.index_select, though it only accepts a 1-dimensional index.
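For reference, the n = 2 case mentioned above could look roughly like the following. This is only a sketch: the names first_wins and selected are made up for illustration, and ties are resolved in favour of the first candidate.

```python
import torch

torch.manual_seed(0)
a = torch.rand(3, 7, 7, 2, 5)
b = torch.rand(3, 7, 7, 2)

# True where the first of the two candidates has the larger (or equal) score.
first_wins = b[..., 0] >= b[..., 1]            # shape (3, 7, 7)

# Broadcast the mask over the trailing dimension of size 5 and pick per position.
selected = torch.where(first_wins.unsqueeze(-1), a[..., 0, :], a[..., 1, :])
print(selected.shape)  # torch.Size([3, 7, 7, 5])
```

This only works because there are exactly two candidates to choose between; it does not extend directly to larger n.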

Apparently, it can be achieved as follows:

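# index of the largest element of b along the last dimension, expanded to match a
idx = b.argmax(-1, keepdim=True).unsqueeze(-1).expand(-1,-1,-1,-1,5)
# gather returns shape (3,7,7,1,5); squeeze(3) drops the singleton dimension
torch.gather(a, 3, idx).squeeze(3)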
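To check that the gather approach above really selects the right elements for n != 2, it can be wrapped in a small helper and compared against a brute-force loop. This is a sketch, not a definitive implementation: the helper name select_by_argmax is hypothetical, and it assumes a has exactly one more trailing dimension than b.

```python
import torch

def select_by_argmax(a, b):
    """Pick a[..., k, :] where k is the argmax of b along its last dimension."""
    idx = b.argmax(dim=-1, keepdim=True)                      # (..., 1)
    idx = idx.unsqueeze(-1).expand(*idx.shape, a.shape[-1])   # (..., 1, F)
    # gather keeps a singleton dimension where n used to be; squeeze removes it.
    return torch.gather(a, dim=-2, index=idx).squeeze(-2)

torch.manual_seed(0)
n = 4
a = torch.rand(3, 7, 7, n, 5)
b = torch.rand(3, 7, 7, n)
out = select_by_argmax(a, b)

# Brute-force reference: look up the winning slice one position at a time.
expected = torch.empty(3, 7, 7, 5)
for i in range(3):
    for j in range(7):
        for k in range(7):
            expected[i, j, k] = a[i, j, k, int(b[i, j, k].argmax())]

print(out.shape, torch.equal(out, expected))  # torch.Size([3, 7, 7, 5]) True
```

Because expand only creates a view, the index tensor is not materialised at full size, and neither a nor b has to be reshaped.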