I am trying to slice a batch tensor of shape (batch, channels, size, size) with a batch of coordinates of shape (batch, 2). This is similar to this topic, which does not seem to have a proper solution.
Here is a toy example:
```python
import torch

batch = 6
channels = 5
board_size = 4
agent_size = 2
iterations = 1000

B = torch.FloatTensor(batch, channels, board_size, board_size).uniform_(-100, 100)
pos = torch.LongTensor([[0, 1], [1, 2], [2, 0], [2, 2], [2, 1], [2, 2]])
# pos = torch.randint(0, board_size - agent_size, (batch, 2))
```
In this example, the resulting tensor would have shape (batch, channels, agent_size, agent_size) = (6, 5, 2, 2) and would be formed by (5, 2, 2) blocks that are not aligned in the original tensor.
The problem is that while indexing accepts a multi-element tensor, slicing does not. Therefore, the solution using NumPy slicing notation is not valid and triggers “TypeError: only integer tensors of a single element can be converted to an index”:
```python
M = B[:, :, pos[:, 0]: pos[:, 0] + agent_size, pos[:, 1]: pos[:, 1] + agent_size]
```
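A minimal repro of the distinction (not from my actual code; any small tensor shows it):

```python
import torch

B = torch.arange(16.).view(1, 1, 4, 4)
idx = torch.tensor([0, 2])

# Indexing with a multi-element tensor is fine:
print(B[:, :, idx].shape)  # torch.Size([1, 1, 2, 4])

# Using the same tensor as a slice bound is not:
try:
    B[:, :, idx:idx + 2]
except TypeError as e:
    print(e)  # only integer tensors of a single element can be converted to an index
```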
One alternative would be to pass all the indices of the sliced tensor axis by axis, but that would require building one index tensor per axis, each with as many elements as the final slice (see this post). While this might be simple for lower dimensions and small slices, in my toy example it would require building 4 tensors, each with 6 x 5 x 2 x 2 = 120 indices and each following a different pattern.
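For reference, this is what the axis-by-axis version looks like on the toy example (a sketch; note that broadcasting in advanced indexing means the four index tensors need not be materialized at their full 120-element size, although each still encodes a different pattern):

```python
import torch

batch, channels, board_size, agent_size = 6, 5, 4, 2
B = torch.FloatTensor(batch, channels, board_size, board_size).uniform_(-100, 100)
pos = torch.LongTensor([[0, 1], [1, 2], [2, 0], [2, 2], [2, 1], [2, 2]])

# One index tensor per axis of B; broadcasting expands them jointly
# to the output shape (batch, channels, agent_size, agent_size).
b_idx = torch.arange(batch).view(batch, 1, 1, 1)
c_idx = torch.arange(channels).view(1, channels, 1, 1)
row_idx = pos[:, 0].view(batch, 1, 1, 1) + torch.arange(agent_size).view(1, 1, agent_size, 1)
col_idx = pos[:, 1].view(batch, 1, 1, 1) + torch.arange(agent_size).view(1, 1, 1, agent_size)

M = B[b_idx, c_idx, row_idx, col_idx]
print(M.shape)  # torch.Size([6, 5, 2, 2])
```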
My current solutions are:
- use a loop to traverse the batch dimension
- use index tensors with torch.gather, applied twice (once per spatial axis) to get the sliced “frame”
```python
# method 1: loop along the batch dimension
def multiSlice1(B, pos, size):
    s = B.shape
    M = torch.zeros(s[0], s[1], size, size)
    for i in range(s[0]):
        M[i] = B[i, :, pos[i, 0]: pos[i, 0] + size, pos[i, 1]: pos[i, 1] + size]
    return M

# method 2: two gathers, one per spatial axis
def multiSlice2(B, pos, size):
    pos_row = pos[:, 0]
    pos_row = pos_row.view(-1, 1) + torch.arange(size)               # (batch, size)
    pos_row = pos_row.view(pos_row.shape[0], 1, pos_row.shape[1], 1)
    expanse = list(B.shape)
    expanse[0] = -1
    expanse[2] = -1
    pos_row = pos_row.expand(expanse)                                # (batch, channels, size, board_size)
    M1 = torch.gather(B, 2, pos_row)

    pos_col = pos[:, 1]
    pos_col = pos_col.view(-1, 1) + torch.arange(size)               # (batch, size)
    pos_col = pos_col.view(pos_col.shape[0], 1, 1, pos_col.shape[1])
    expanse = list(M1.shape)
    expanse[0] = -1
    expanse[3] = -1
    pos_col = pos_col.expand(expanse)                                # (batch, channels, size, size)
    M2 = torch.gather(M1, 3, pos_col)
    return M2
```
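For comparison, a third route (a sketch, not one of the two methods above) extracts every sliding window with Tensor.unfold and then picks one window per batch element with plain indexing:

```python
import torch

batch, channels, board_size, agent_size = 6, 5, 4, 2
B = torch.FloatTensor(batch, channels, board_size, board_size).uniform_(-100, 100)
pos = torch.LongTensor([[0, 1], [1, 2], [2, 0], [2, 2], [2, 1], [2, 2]])

# All sliding windows: (batch, channels, 3, 3, agent_size, agent_size)
patches = B.unfold(2, agent_size, 1).unfold(3, agent_size, 1)

# Select, for each batch element, the window whose top-left corner is pos
M = patches[torch.arange(batch), :, pos[:, 0], pos[:, 1]]
print(M.shape)  # torch.Size([6, 5, 2, 2])
```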
Is there any simpler solution that is more efficient?