# How to do an efficient complex `index_select` operation using torch?

How can I efficiently perform an `index_select`-style lookup on an input tensor of size `[1, 32, 32, 12]` using an index tensor of size `[1, 5, 2]`?

Here is my Python code:

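```python
import torch

ret = []
for batch_idx in range(N):
    actions = []
    for agent in range(self.n_agents):
        # (x, y) location of this agent on the map
        x, y = obs[batch_idx][agent]
        # feature vector of shape [feature_dim] at that location
        act = pi[batch_idx, x, y]
        actions.append(act)
    actions = torch.stack(actions)  # [n_agents, feature_dim]
    ret.append(actions)
ret = torch.stack(ret)  # [N, n_agents, feature_dim]
```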

- `N`, the batch size, is 1 (or 32 in the training phase).
- `self.n_agents = 5`.
- `obs` is an index tensor of shape `[1, 5, 2]`, whose dimensions are the batch size, the number of agents, and the agent's `x` and `y` location on the map.
- `pi` is a policy map of shape `[1, 32, 32, 12]`, whose dimensions are the batch size, `map_max_x`, `map_max_y`, and `feature_dim`.
- `ret` is the resulting tensor of shape `[1, 5, 12]`.
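Reading the shapes above, each output row is meant to be the feature vector stored at one agent's location. Written as a formula (with $N$ the batch size and $n_{\text{agents}} = 5$):

$$
\mathrm{ret}[b, a, :] = \mathrm{pi}\big[\,b,\; \mathrm{obs}[b, a, 0],\; \mathrm{obs}[b, a, 1],\; :\,\big],
\qquad 0 \le b < N,\quad 0 \le a < n_{\text{agents}}
$$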

This code becomes far too slow when the batch size grows from 1 to 32, because every single lookup runs inside nested Python loops.
How can I write this efficiently in idiomatic PyTorch?
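One vectorized option, shown here as an alternative to the reshape approach the poster eventually used, is PyTorch advanced indexing: indexing `pi` with three integer tensors that broadcast to `[N, n_agents]` picks one `[feature_dim]` row per (batch, agent) pair in a single call. This is a minimal sketch, assuming `obs` is a `torch.long` tensor with `(x, y)` in its last dimension; the function names and the random test data are hypothetical.

```python
import torch


def select_actions_loop(pi: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    """Reference version: the nested loops from the question, with (x, y) unpacked."""
    N, n_agents, _ = obs.shape
    ret = []
    for b in range(N):
        actions = []
        for a in range(n_agents):
            x, y = obs[b, a]
            actions.append(pi[b, x, y])
        ret.append(torch.stack(actions))
    return torch.stack(ret)


def select_actions_indexing(pi: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    """Vectorized version using advanced indexing -> [N, n_agents, feature_dim]."""
    N = obs.shape[0]
    # [N, 1] batch index, broadcast against the [N, n_agents] coordinate tensors
    batch = torch.arange(N, device=obs.device).unsqueeze(1)
    return pi[batch, obs[..., 0], obs[..., 1]]


# Hypothetical test data with the shapes from the question, batch size 32
N, n_agents, map_size, feature_dim = 32, 5, 32, 12
pi = torch.randn(N, map_size, map_size, feature_dim)
obs = torch.randint(0, map_size, (N, n_agents, 2))

ret = select_actions_indexing(pi, obs)
print(ret.shape)  # torch.Size([32, 5, 12])
print(torch.equal(ret, select_actions_loop(pi, obs)))  # True
```

Because the three index tensors sit next to each other at the front, the result shape is their broadcast shape `[N, n_agents]` followed by the untouched `feature_dim` dimension.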

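OK, the problem is solved.
Reshaping `pi` into a 2-D table of feature rows before calling `index_select` is more efficient, because all agents in all samples are then gathered with a single `index_select` call.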
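Why the reshape works: `reshape` keeps elements in row-major (C) order, so the cell `pi[b, x, y]` of a `[N, X, Y, F]` map becomes row `b * X * Y + x * Y + y` of `pi.reshape(-1, F)`. A small check of that identity, with arbitrary sizes chosen only for illustration (`flat_row` is a hypothetical helper):

```python
import torch

# Arbitrary small sizes for illustration: batch N, map X x Y, feature F
N, X, Y, F = 2, 4, 3, 5
pi = torch.arange(N * X * Y * F).reshape(N, X, Y, F)
flat = pi.reshape(-1, F)  # [N * X * Y, F]


def flat_row(b: int, x: int, y: int) -> int:
    """Row of `flat` that holds pi[b, x, y] (row-major layout)."""
    return b * X * Y + x * Y + y


b, x, y = 1, 2, 1
print(flat_row(b, x, y))  # 19
print(torch.equal(flat[flat_row(b, x, y)], pi[b, x, y]))  # True
```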

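```python
# the example code
# Flatten the policy map so every (batch, x, y) cell becomes one row:
# [N, map_size, map_size, act_dim] -> [N * map_size * map_size, act_dim]
pi = pi.reshape(-1, self.act_dim)
# Turn each agent's (x, y) location into a row index inside one map:
# [N, n_agents, 2] -> [N, n_agents]
obs = obs[..., 0] * self.map_size + obs[..., 1]

selected_agents_idx = []
for batch_idx in range(N):
    # Shift the indices into this sample's slice of the flattened map
    batch_select_agent_idx = obs[batch_idx] + (batch_idx * self.map_size * self.map_size)
    selected_agents_idx.append(batch_select_agent_idx)
selected_agents_idx = torch.cat(selected_agents_idx)  # [N * n_agents]

# One lookup for all agents in all samples: [N * n_agents, act_dim]
action_selected = pi.index_select(0, selected_agents_idx)
action_selected = action_selected.reshape(N, self.n_agents, -1)  # [N, n_agents, act_dim]
```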
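The remaining loop over `batch_idx` only adds a per-sample offset, so it can also be replaced by `torch.arange` plus broadcasting. Below is a sketch of the same flatten-then-`index_select` idea as a standalone function, assuming `obs` holds `torch.long` `(x, y)` coordinates and the map is square (`map_size x map_size`) as in the question; the function name and the random test data are hypothetical.

```python
import torch


def select_actions_flat(pi: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    """pi: [N, map_size, map_size, act_dim], obs: [N, n_agents, 2] (long)
    -> [N, n_agents, act_dim]"""
    N, map_size, _, act_dim = pi.shape
    n_agents = obs.shape[1]
    # Row index of each agent's cell inside its own map: [N, n_agents]
    cell = obs[..., 0] * map_size + obs[..., 1]
    # Start row of each sample's map in the flattened tensor: [N, 1]
    offset = torch.arange(N, device=obs.device).unsqueeze(1) * map_size * map_size
    rows = (cell + offset).reshape(-1)  # [N * n_agents]
    selected = pi.reshape(-1, act_dim).index_select(0, rows)
    return selected.reshape(N, n_agents, act_dim)


# Hypothetical test data with the shapes from the question, batch size 32
N, n_agents, map_size, act_dim = 32, 5, 32, 12
pi = torch.randn(N, map_size, map_size, act_dim)
obs = torch.randint(0, map_size, (N, n_agents, 2))

out = select_actions_flat(pi, obs)
print(out.shape)  # torch.Size([32, 5, 12])
# Spot-check one entry against direct indexing
b, a = 3, 2
x, y = obs[b, a].tolist()
print(torch.equal(out[b, a], pi[b, x, y]))  # True
```

With no Python loop left, the cost no longer grows with the number of Python iterations, only with the size of the single `index_select`.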