How can I do an efficient index_select operation over an input tensor of size [1, 32, 32, 12] with an index tensor of size [1, 5, 2]?
Here is my Python code:
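ret = []
for batch_idx in range(N):
    actions = []
    for agent in range(self.n_agents):
        agent_idx = obs[batch_idx][agent]  # (x, y) location of this agent
        act = pi[batch_idx, agent_idx[0], agent_idx[1]]  # feature vector at that location
        actions.append(act)
    actions = torch.stack(actions)  # [n_agents, feature_dim]
    ret.append(actions)
ret = torch.stack(ret)  # [N, n_agents, feature_dim]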
N = 1 # batch size; 32 in the training phase
self.n_agents = 5
obs is an index tensor of shape [1, 5, 2], whose dimensions are batch_size, num_of_agents, and the (x, y) location coordinates, respectively.
pi is a policy map of shape [1, 32, 32, 12], whose dimensions are batch_size, map_max_x, map_max_y, and feature_dim, respectively.
ret is a tensor of shape [1, 5, 12].
This code becomes too slow when the batch size grows from 1 to 32.
How can I write this efficiently, in idiomatic PyTorch style?
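For reference, the double loop above could probably be replaced by a single advanced-indexing expression. The following is only a minimal sketch, not a tested solution for the real model: the function name gather_agent_features is made up for illustration, and it assumes that obs holds integer coordinates inside the map bounds, with obs[..., 0] indexing dimension 1 of pi and obs[..., 1] indexing dimension 2.

```python
import torch

def gather_agent_features(pi, obs):
    """Hypothetical helper: pick pi[b, x, y, :] for every agent in every batch.

    pi:  [B, X, Y, F] policy map
    obs: [B, A, 2] integer (x, y) coordinates (assumed to be in bounds)
    returns: [B, A, F]
    """
    B = pi.shape[0]
    # Column vector of batch indices, shape [B, 1]; it broadcasts against [B, A].
    batch_idx = torch.arange(B, device=pi.device).unsqueeze(1)
    x = obs[..., 0].long()  # [B, A]
    y = obs[..., 1].long()  # [B, A]
    # The three index tensors broadcast to [B, A]; the feature dim is kept.
    return pi[batch_idx, x, y]

# Example with the shapes from the question (random data, for illustration only).
pi = torch.randn(1, 32, 32, 12)
obs = torch.randint(0, 32, (1, 5, 2))
ret = gather_agent_features(pi, obs)  # shape [1, 5, 12]
```

Because the three index tensors broadcast to [batch_size, num_of_agents], no Python loop is needed and the same call should work unchanged for a batch size of 32.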