Is there a better or more obvious way to do the following operation:
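```python
import torch

A = torch.randn(3, 3, 4)
temp_m, temp_am = A.max(dim=0)  # max values and their indices along dim 0, each of shape (3, 4)
print(A)
print(temp_am.shape, temp_m.shape)
print("---max----")
print(temp_m)
print("---indexed---")
# expand the (3, 4) indices to A's shape, gather along dim 0, then keep one of the identical slices
indexed_max = torch.gather(A, 0, temp_am.expand_as(A))[0]
print(indexed_max)
print(torch.all(indexed_max == temp_m))
```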
My intent is to compute the max values of one tensor along a dimension and then use the argmax output to index other tensors.
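A minimal sketch of that intent, assuming the other tensor (called `B` here, a hypothetical name) has the same shape as `A`. The only trick is that `gather` wants an index with the same number of dimensions as the tensor it reads from, so the reduced dimension is added back with `unsqueeze(0)` and removed again afterwards:

```python
import torch

torch.manual_seed(0)  # reproducible example

# A drives the selection; B is a hypothetical "other" tensor with the same shape
A = torch.randn(3, 3, 4)
B = torch.randn(3, 3, 4)

# values and indices of the max along dim 0; both have shape (3, 4)
max_vals, max_idx = A.max(dim=0)

# gather needs an index with as many dims as the source, so restore dim 0
# (shape (1, 3, 4)) and squeeze it away after gathering
idx = max_idx.unsqueeze(0)
A_at_max = A.gather(0, idx).squeeze(0)  # same values as max_vals
B_at_max = B.gather(0, idx).squeeze(0)  # B's entries at A's argmax positions

print(torch.equal(A_at_max, max_vals))  # True
```

The same `idx` can be reused for any number of tensors shaped like `A`, which avoids the `expand_as` plus `[0]` step.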
LeviViana
(Levi Viana)
Maybe this is what you want:
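```python
import torch
A = torch.randn(3, 3, 4)
temp_m, temp_am = A.max(dim=0)
print(A)
print(temp_am.shape, temp_m.shape)
print("---max----")
print(temp_m)
print("---indexed---")
indexed_max = torch.gather(A, 0, temp_am.expand_as(A))[0]
print(indexed_max)
print(torch.all(indexed_max == temp_m))

# More direct way: argmax + gather
idx = A.argmax(0)  # shape (3, 4), the same indices as temp_am
# gather needs an index with as many dims as A, so restore dim 0 (shape (1, 3, 4)),
# then squeeze it away so the result has the same (3, 4) shape as temp_m
result = A.gather(0, idx.unsqueeze(0)).squeeze(0)
print(result)
print(torch.all(result == temp_m))
```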
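A variant that skips the manual `unsqueeze`, as a sketch assuming PyTorch 1.9 or later (needed for `torch.take_along_dim`): with `keepdim=True`, `max` returns indices that keep the reduced dimension with size 1, so they can be passed straight to `gather` or to `torch.take_along_dim` (the NumPy `take_along_axis` analogue). `B` is again a hypothetical second tensor:

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3, 4)
B = torch.randn(3, 3, 4)  # hypothetical tensor to index with A's argmax

# keepdim=True keeps dim 0 with size 1, so both outputs have shape (1, 3, 4)
max_vals, max_idx = A.max(dim=0, keepdim=True)

# Option 1: gather accepts the keepdim indices directly
via_gather = B.gather(0, max_idx)

# Option 2: take_along_dim (PyTorch >= 1.9)
via_take = torch.take_along_dim(B, max_idx, dim=0)

print(torch.equal(via_gather, via_take))  # True
print(via_gather.squeeze(0).shape)        # torch.Size([3, 4])
```

Keeping the size-1 dimension is handy if the result is later broadcast against the original `(3, 3, 4)` tensors; call `squeeze(0)` when a plain `(3, 4)` result is wanted.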