How to do an efficient complex index_select operation using torch?

How can I do an efficient index_select operation over an input tensor of size [1, 32, 32, 12] with an index tensor of size [1, 5, 2]?

Here is my Python code:

ret = []
for batch_idx in range(N):
    actions = []
    for agent in range(self.n_agents):
        agent_idx = obs[batch_idx][agent]
        act = pi[batch_idx, agent_idx[0], agent_idx[1]]
        actions.append(act)
    actions = torch.stack(actions)
    ret.append(actions)
ret = torch.stack(ret)

N = 1  # batch size; 32 in the training phase
self.n_agents = 5
obs is an index tensor of shape [1, 5, 2], whose dimensions are batch_size, num_of_agents, and the x and y location coordinates respectively.
pi is a policy map of shape [1, 32, 32, 12], whose dimensions are batch_size, map_max_x, map_max_y, and feature_dim respectively.
ret is a tensor of shape [1, 5, 12].

This code is too slow when the batch size grows from 1 to 32.
How can I write this efficiently in idiomatic torch code?

Ok, the problem is solved.
Flattening pi with reshape and then calling index_select once on flat row indices is more efficient than the nested loops.

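# the example code
        pi = pi.reshape(-1, self.act_dim)  # [N * map_size * map_size, act_dim]
        # flatten each (x, y) pair into a row offset within one map: x * map_size + y
        flat_obs = obs[..., 0] * self.map_size + obs[..., 1]  # [N, n_agents]

        selected_agents_idx = []
        for batch_idx in range(N):
            # shift by the number of rows that belong to the preceding batch elements
            batch_select_agent_idx = flat_obs[batch_idx] + (batch_idx * self.map_size * self.map_size)
            selected_agents_idx.append(batch_select_agent_idx)
        selected_agents_idx = torch.cat(selected_agents_idx)

        action_selected = pi.index_select(0, selected_agents_idx)
        action_selected = action_selected.reshape(N, self.n_agents, -1)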
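
For comparison, the remaining Python loop over the batch can probably be removed as well. The sketch below is not the code from my project: the function names (select_loop, select_advanced, select_flat) and the random test tensors are made up for illustration, and it assumes obs holds int64 coordinates that lie inside the map. It shows two loop-free variants, one using advanced indexing and one using index_select on a flattened pi, and checks both against the nested-loop version from the question.

```python
import torch


def select_loop(pi, obs):
    """Reference: the nested-loop version from the question."""
    N, n_agents = obs.shape[0], obs.shape[1]
    ret = []
    for b in range(N):
        acts = [pi[b, obs[b, a, 0], obs[b, a, 1]] for a in range(n_agents)]
        ret.append(torch.stack(acts))
    return torch.stack(ret)


def select_advanced(pi, obs):
    """Advanced indexing: one index tensor per leading dimension of pi."""
    N = obs.shape[0]
    # [N, 1] batch index, broadcast against the [N, n_agents] coordinates
    batch = torch.arange(N, device=obs.device).unsqueeze(1)
    return pi[batch, obs[..., 0], obs[..., 1]]  # [N, n_agents, act_dim]


def select_flat(pi, obs):
    """index_select on a flattened pi, with the batch offset computed without a loop."""
    N, map_x, map_y, act_dim = pi.shape
    n_agents = obs.shape[1]
    # number of rows that precede each batch element in the flattened pi
    offset = torch.arange(N, device=obs.device).unsqueeze(1) * (map_x * map_y)
    flat_idx = (offset + obs[..., 0] * map_y + obs[..., 1]).reshape(-1)
    out = pi.reshape(-1, act_dim).index_select(0, flat_idx)
    return out.reshape(N, n_agents, act_dim)


# Hypothetical test data with the shapes described in the question
torch.manual_seed(0)
N, map_size, n_agents, act_dim = 32, 32, 5, 12
pi = torch.randn(N, map_size, map_size, act_dim)
obs = torch.randint(0, map_size, (N, n_agents, 2))  # int64 coordinates

expected = select_loop(pi, obs)
assert torch.equal(select_advanced(pi, obs), expected)
assert torch.equal(select_flat(pi, obs), expected)
```

Both variants return a tensor of shape [N, n_agents, act_dim]. The advanced-indexing form is shorter and does not need the manual offset arithmetic, so it may be the easier one to maintain.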