My original code uses torch.gather:
```python
import torch

# outputs_x: 68x64, max_ids: 68x1 (integer column indices)
outputs_x_select = torch.gather(outputs_x, 1, max_ids)  # 68x1
outputs_x_select = outputs_x_select.squeeze(1)  # 68
```
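To make the target behavior concrete, here is a small sketch on toy data (the 3x4 tensor and its indices are made up for illustration). For each row `i`, gather along dim 1 picks exactly one element, `outputs_x[i, max_ids[i, 0]]`:

```python
import torch

# Small stand-in for outputs_x (3 rows x 4 columns) and max_ids (3x1).
outputs_x = torch.tensor([[0.1, 0.9, 0.3, 0.2],
                          [0.5, 0.4, 0.8, 0.0],
                          [0.7, 0.6, 0.1, 0.3]])
max_ids = torch.tensor([[1], [2], [0]])

# gather along dim 1: result[i][0] = outputs_x[i][max_ids[i][0]]
selected = torch.gather(outputs_x, 1, max_ids).squeeze(1)
print(selected)  # tensor([0.9000, 0.8000, 0.7000])
```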
I want an alternative that does not use torch.gather. So far, I have found a way that uses index_select:
```python
# outputs_x: 68x64, max_ids: 68x1 before squeezing, 68 after
max_ids = max_ids.squeeze(1)  # 68
outputs_x_select = torch.index_select(outputs_x, 1, max_ids)  # 68x68
outputs_x_select = outputs_x_select[range(len(max_ids)), range(len(max_ids))]  # keep the diagonal -> 68
```
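A quick sanity check of the index_select approach, using random data with the shapes from the question (the seed and values are arbitrary). It shows that the intermediate is 68x68, with entry `(i, j)` equal to `outputs_x[i, max_ids[j]]`, and that only its diagonal matches the gather result:

```python
import torch

torch.manual_seed(0)
outputs_x = torch.randn(68, 64)
max_ids = torch.randint(0, 64, (68, 1))

# Reference result with gather.
ref = torch.gather(outputs_x, 1, max_ids).squeeze(1)

# index_select picks column ids[j] for EVERY row, so the result is 68x68.
ids = max_ids.squeeze(1)
full = torch.index_select(outputs_x, 1, ids)

# Only entries with i == j are needed; full.diagonal() is an equivalent shortcut.
diag = full[torch.arange(len(ids)), torch.arange(len(ids))]
```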
index_select selects a whole column (or row) for every index, so the intermediate result here is 68x68 instead of 68 elements, which consumes more memory than gather. Is there a better way than mine to reproduce gather's behavior?
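One option worth comparing against is PyTorch's integer-array (advanced) indexing, which pairs a row index tensor with a column index tensor element by element and so returns only N values, with no N x N intermediate. This is a sketch, not a benchmarked answer; the helper name `select_per_row` is hypothetical:

```python
import torch

def select_per_row(outputs_x: torch.Tensor, max_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return outputs_x[i, max_ids[i]] for each row i, without torch.gather.

    outputs_x: (N, C) tensor; max_ids: (N,) or (N, 1) integer tensor.
    """
    ids = max_ids.reshape(-1)  # (N, 1) -> (N,)
    rows = torch.arange(outputs_x.size(0), device=outputs_x.device)
    # rows[k] is paired with ids[k], so the output has N elements.
    return outputs_x[rows, ids]

# Usage on the shapes from the question (random data for illustration).
outputs_x = torch.randn(68, 64)
max_ids = torch.randint(0, 64, (68, 1))
outputs_x_select = select_per_row(outputs_x, max_ids)  # shape: (68,)
```

The row index tensor is created on the same device as `outputs_x` so the sketch also works for CUDA tensors.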