How can I do an efficient index_select operation on an input tensor of size [1, 32, 32, 12] using an index tensor of size [1, 5, 2]?
Here is my Python code:
```python
import torch

N = 1              # batch size (32 in the training phase)
self.n_agents = 5  # number of agents

ret = []
for batch_idx in range(N):
    actions = []
    for agent in range(self.n_agents):
        agent_idx = obs[batch_idx][agent]                # (x, y) of this agent
        act = pi[batch_idx, agent_idx[0], agent_idx[1]]  # feature vector, shape [12]
        actions.append(act)
    actions = torch.stack(actions)  # [n_agents, 12]
    ret.append(actions)
ret = torch.stack(ret)              # [N, n_agents, 12]
```
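A minimal sketch of a loop-free version, assuming `obs` holds integer (x, y) coordinates that lie inside the map. It relies on PyTorch advanced indexing: a [N, 1] batch index broadcasts against the [N, 5] x and y index tensors, so a single indexing call picks one 12-dimensional feature vector per (batch, agent) pair. The function name `select_actions` is hypothetical.

```python
import torch

def select_actions(pi: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    """Hypothetical vectorized replacement for the double loop.

    pi:  [N, X, Y, F] policy map
    obs: [N, A, 2] integer (x, y) coordinates per agent
    returns: [N, A, F]
    """
    N = pi.shape[0]
    obs = obs.long()  # index tensors must be integer-typed
    batch = torch.arange(N, device=pi.device).unsqueeze(1)  # [N, 1], broadcasts to [N, A]
    x = obs[..., 0]  # [N, A]
    y = obs[..., 1]  # [N, A]
    # The three index tensors broadcast to [N, A]; the trailing feature dim F is kept.
    return pi[batch, x, y]

# Usage with the training-phase batch size of 32
pi = torch.randn(32, 32, 32, 12)
obs = torch.randint(0, 32, (32, 5, 2))
ret = select_actions(pi, obs)
print(ret.shape)  # torch.Size([32, 5, 12])
```

Element `ret[b, a]` equals `pi[b, obs[b, a, 0], obs[b, a, 1]]`, which is exactly what the original loop computes.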
`obs` is an index tensor of shape [1, 5, 2], whose dimensions are the batch size, the number of agents, and the (x, y) location coordinates, respectively.
`pi` is a policy map of shape [1, 32, 32, 12], whose dimensions are the batch size, map_max_x, map_max_y, and feature_dim, respectively.
The result `ret` is a tensor of shape [1, 5, 12].
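Since the question asks about `index_select` specifically, here is a sketch of the same lookup expressed as one `index_select` call, assuming `pi` is laid out as [N, X, Y, F] as described and all coordinates are in range. It flattens the batch and map dimensions into rows, converts each (batch, x, y) triple into a row-major linear index `(b * X + x) * Y + y`, selects those rows, and reshapes back to [N, A, F]. The function name is hypothetical.

```python
import torch

def select_actions_index_select(pi: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    """Hypothetical lookup via a single index_select over a flattened policy map."""
    N, X, Y, F = pi.shape
    A = obs.shape[1]
    obs = obs.long()
    batch = torch.arange(N, device=pi.device).unsqueeze(1)  # [N, 1]
    # Row-major linear index of (b, x, y) in an [N, X, Y] grid
    flat_idx = (batch * X + obs[..., 0]) * Y + obs[..., 1]  # [N, A]
    # View pi as [N*X*Y, F], pick one row per (batch, agent), restore [N, A, F]
    rows = pi.reshape(N * X * Y, F).index_select(0, flat_idx.reshape(-1))
    return rows.reshape(N, A, F)
```

For the batch-32 case, passing `pi` of shape [32, 32, 32, 12] and `obs` of shape [32, 5, 2] yields a [32, 5, 12] tensor, the same result as the nested loop.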
This code becomes far too slow when the batch size grows from 1 to 32.
How can I write it efficiently in idiomatic PyTorch?