Slice a tensor on its outermost dimension

Is there a better or more obvious way to do the following operation:

import torch
A = torch.randn(3, 3, 4)
temp_m, temp_am = A.max(dim=0)
print(A)
print(temp_am.shape, temp_m.shape)
print("---max----")
print(temp_m)
print("---indexed---")
indexed_max = torch.gather(A, 0, temp_am.expand_as(A))[0]
print(indexed_max)
print(torch.all(indexed_max == temp_m))

My intent is to compute the max of one tensor along dim 0 and use the argmax output to index other tensors.

Maybe this is what you want:

import torch
A = torch.randn(3, 3, 4)
temp_m, temp_am = A.max(dim=0)

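# More direct way: argmax along dim 0 gives indices of shape (3, 4);
# unsqueeze(0) makes them shape (1, 3, 4) so gather can use them directly
idx = A.argmax(0)
result = A.gather(0, idx.unsqueeze(0)).squeeze(0)  # squeeze back to shape (3, 4)
print(result)
print(torch.all(result == temp_m))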
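Since the stated goal is to reuse the argmax to index *other* tensors, here is a minimal sketch of that, assuming a hypothetical second tensor `B` with the same shape as `A`. Passing `keepdim=True` to `argmax` avoids the manual `unsqueeze`, and `torch.take_along_dim` (available in PyTorch 1.9 and later) is an alternative spelling of the same `gather` call. The last part shows the equivalent selection with plain advanced indexing, for comparison.

```python
import torch

torch.manual_seed(0)  # only so the sketch is reproducible

A = torch.randn(3, 3, 4)
B = torch.randn(3, 3, 4)  # hypothetical second tensor, same shape as A

# keepdim=True keeps the reduced dim, so idx has shape (1, 3, 4)
# and can be passed to gather / take_along_dim without unsqueeze
idx = A.argmax(dim=0, keepdim=True)

# Max of A along dim 0, shape (3, 4)
a_max = A.gather(0, idx).squeeze(0)

# Values of B at the positions where A is maximal, shape (3, 4)
b_at_a_max = torch.take_along_dim(B, idx, dim=0).squeeze(0)

# Same selection with advanced indexing: for each (j, k) pick B[idx[j, k], j, k]
j = torch.arange(B.shape[1]).unsqueeze(1)  # shape (3, 1)
k = torch.arange(B.shape[2]).unsqueeze(0)  # shape (1, 4)
b_adv = B[idx.squeeze(0), j, k]            # broadcasts to shape (3, 4)

print(torch.equal(a_max, A.max(dim=0).values))  # True
print(torch.equal(b_at_a_max, b_adv))           # True
```

The `gather`/`take_along_dim` form generalizes to any `dim` without building index grids by hand, which is why it is usually the more convenient choice once the indices have been kept with `keepdim=True`.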