Selecting rows from a 3D tensor using a 3D span index

Hi all,

I was wondering if there is a smart, PyTorch-ic way to extract the rows of a 3D tensor based on some 3D ‘span’ index tensor.
In detail, let’s say you have a (32, 20, 768) tensor x, where 32 is the batch size, 20 is the sequence length (sentence length, for instance) and 768 is the encoding dimension. You also have an index tensor i of shape (32, 3, 2), where again the first dimension is the batch size (corresponding to x), the second dimension is the number of spans you want to extract, and the last one is equal to 2 since a span is defined by two integers.
For example let’s say that i[0] is:

[ [1,2], [5,7], [15,19] ]

what I would like to extract from x[0] is:

x[0][1:2]
x[0][5:7]
x[0][15:19]

and build a new tensor y with y[0] of the form:

[ x[0][1:2].mean(0),
x[0][5:7].mean(0),
x[0][15:19].mean(0) ]

and in general

y[s][u] = x[s][ i[s][u][0] : i[s][u][1] ].mean(0)

with s in range(32) and u in range(3) for each s.
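In code, the plain for-loop version I’m trying to avoid would look something like this (with made-up random spans):

```python
import torch

batch_size, seq, emb, n_spans = 32, 20, 768, 3
x = torch.randn(batch_size, seq, emb)
# random sorted (start, end) pairs; +1 on the end makes every span non-empty
i = torch.randint(0, seq - 1, (batch_size, n_spans, 2)).sort(dim=-1)[0]
i[..., 1] += 1

# naive loop version of y[s][u] = x[s][i[s][u][0]:i[s][u][1]].mean(0)
y = torch.stack([
    torch.stack([x[s, i[s, u, 0]:i[s, u, 1]].mean(0) for u in range(n_spans)])
    for s in range(batch_size)
])
assert y.shape == (batch_size, n_spans, emb)
```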

I guess the solution should be somewhat similar to what was discussed here, but I can’t figure out how to achieve the desired result while avoiding for loops.

I haven’t checked this code carefully, but I hope it helps…

# generate input data
import torch

batch_size, seq, embedding, indices_len = 30, 20, 5, 3
x = torch.randn((batch_size, seq, embedding))
indices = torch.randint(0, seq - 1, (batch_size, indices_len, 2)).sort(dim=-1)[0]
indices[:, :, 1] += 1  # end index is exclusive, so every span is non-empty

# Now let's calculate the result.
# Step 1. Note that x[start:end].sum(0) equals cumsum[end] - cumsum[start]
# once a zero row is prepended to the cumulative sum. I don't know how to
# gather from a 3D tensor per batch directly, so reshape to 2D first.
cumsum_x = torch.cumsum(x, dim=1)
cumsum_x = torch.cat([torch.zeros(batch_size, 1, embedding), cumsum_x], dim=1)
cumsum_x = cumsum_x.view(-1, embedding)  # (30 * 21, 5)

# Step 2. Flatten indices (the stride is seq + 1 because of the prepended row).
lower_index = (indices[..., 0] + torch.arange(batch_size)[:, None] * (seq + 1)).view(-1)
higher_index = (indices[..., 1] + torch.arange(batch_size)[:, None] * (seq + 1)).view(-1)

# Step 3. Select elements.
lower = cumsum_x[lower_index, :].view(batch_size, indices_len, embedding)
higher = cumsum_x[higher_index, :].view(batch_size, indices_len, embedding)

# Step 4. Divide the span sums by the span lengths to get the means.
step = indices[..., 1:2] - indices[..., 0:1]
result = (higher - lower) / step.float()
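For what it’s worth, here is the same cumsum idea wrapped in a self-contained function (a zero row is prepended so that cumsum[end] - cumsum[start] covers exactly x[start:end], avoiding an off-by-one at the span start), sanity-checked against a plain loop:

```python
import torch

def span_mean(x, indices):
    """Mean of x[s, a:b] for each span (a, b) in indices, via cumulative sums."""
    batch_size, seq, emb = x.shape
    # zero row in front so that csum[:, k] holds x[:, :k].sum(1)
    csum = torch.cat([x.new_zeros(batch_size, 1, emb), x.cumsum(dim=1)], dim=1)
    flat = csum.view(-1, emb)  # (batch_size * (seq + 1), emb)
    offset = torch.arange(batch_size, device=x.device)[:, None] * (seq + 1)
    lo = flat[(indices[..., 0] + offset).reshape(-1)]
    hi = flat[(indices[..., 1] + offset).reshape(-1)]
    lengths = (indices[..., 1] - indices[..., 0]).reshape(-1, 1).float()
    return ((hi - lo) / lengths).view(batch_size, indices.shape[1], emb)

x = torch.randn(4, 10, 6)
idx = torch.randint(0, 9, (4, 3, 2)).sort(dim=-1)[0]
idx[..., 1] += 1  # end exclusive, every span non-empty
out = span_mean(x, idx)

# sanity check against the loop version
ref = torch.stack([torch.stack([x[s, a:b].mean(0) for a, b in idx[s]])
                   for s in range(4)])
assert torch.allclose(out, ref, atol=1e-5)
```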

Hey, thanks a lot!
I’ll check your proposed solution. In the meantime I actually came up with an alternative solution using tensor algebra.

The idea is to convert the index tensor into another tensor that, multiplied by the encoding tensor, gives the desired result. So, for example, let’s consider an easier case to visualize, where we have batch size = 3, encoding dimension = 6 and sequence length = 5. Therefore, x now is (3, 5, 6) and let’s say that our i is of shape (3, 2, 2), with i[0] given by:

i[0] = [ [1,3], [4,5] ]

you want to convert it to a 2x5 matrix m that is non-zero only in the positions defined by i[0]:

m = [ [ 0, 0.5, 0.5, 0, 0 ],
[ 0, 0, 0, 0, 1 ] ]
(note that each row needs to sum to 1 since it contains the weights for the computation of the average)

such that when you perform the matrix multiplication: m.mm(x[0]), you get the desired 2x6 matrix:

[ x[0][1:3].mean(0),
x[0][4:5].mean(0) ]

Now you just have to generalize the thing to the whole batch by using the generalized tensor product. So, you want to find the tensor M that multiplied by your x of shape (3, 5, 6) gives the desired (3, 2, 6) tensor, let’s call this y, containing the mean of the rows specified in i.

y = torch.tensordot(M, x)

Since x can be viewed as a 15x6 matrix, it’s easy to see that M is going to be a 6x15 matrix, such that y = M·x is a 6x6 matrix, or equivalently the (3, 2, 6) tensor we were looking for. In particular, M corresponds to a (3, 3, 2, 5) tensor, i.e. a 3x3 matrix of 2x5 blocks, where each 2x5 block on the diagonal is built as done with m and i[0] above, and each off-diagonal block is just zero. With M defined this way, you obtain the desired y simply by performing the product:

y = torch.tensordot( M, x, dims=[ [1,3], [0,1] ] )

The following is the code corresponding to the example above:

In [122]: x = torch.randn(3, 5, 6)

In [123]: x
Out[123]: 
tensor([[[ 0.1536,  0.9683, -0.1524,  0.0499, -1.5984, -0.6716],
         [-0.6679,  0.4353,  0.1829,  1.1936,  0.7692,  0.3316],
         [-0.4091,  0.2014, -1.1973,  0.4434,  0.5641, -0.5315],
         [-0.4476, -0.2807, -0.9703, -0.3468,  3.5646,  0.0424],
         [-0.9625, -0.4611, -0.5875, -1.4666,  0.4561, -0.6519]],

        [[ 0.2110, -1.8854, -0.6089,  0.4814,  0.0223,  0.4263],
         [-0.1546, -2.3762,  1.4597,  0.1886, -0.7254,  0.8121],
         [ 1.2119, -0.8752,  0.7288, -1.2654, -0.7958, -0.0572],
         [ 1.9370, -1.4721,  1.0472, -1.4682,  1.0570,  1.0122],
         [-0.9708, -0.7558, -2.5213, -0.2033, -0.3877, -0.8172]],

        [[-0.8187, -0.4423,  0.7459,  1.0440,  0.1667, -0.6885],
         [ 0.1885, -0.5397, -0.2614,  0.0529, -0.0420, -0.1613],
         [-1.3891, -0.6847,  0.1743,  0.5257,  0.5144, -0.6880],
         [-0.2109,  1.1182,  0.3880, -0.8249,  0.1124, -0.0780],
         [ 1.6060, -0.5759, -0.3814, -0.6349, -0.8468,  2.1026]]])

In [124]: m = torch.zeros(2,5)

In [125]: m[0,1:3] = 0.5

In [126]: m[1,4] = 1

In [127]: m
Out[127]: 
tensor([[0.0000, 0.5000, 0.5000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]])

In [130]: M = torch.zeros(3, 3, 2, 5)

In [131]: M[0,0], M[1,1], M[2,2] = m, m, m

In [132]: M
Out[132]: 
tensor([[[[0.0000, 0.5000, 0.5000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],


        [[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.5000, 0.5000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],


        [[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.5000, 0.5000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]]])

In [133]: y = torch.tensordot(M,x,dims=[[1,3],[0,1]])

In [134]: y
Out[134]: 
tensor([[[-0.5385,  0.3183, -0.5072,  0.8185,  0.6666, -0.1000],
         [-0.9625, -0.4611, -0.5875, -1.4666,  0.4561, -0.6519]],

        [[ 0.5286, -1.6257,  1.0942, -0.5384, -0.7606,  0.3775],
         [-0.9708, -0.7558, -2.5213, -0.2033, -0.3877, -0.8172]],

        [[-0.6003, -0.6122, -0.0436,  0.2893,  0.2362, -0.4246],
         [ 1.6060, -0.5759, -0.3814, -0.6349, -0.8468,  2.1026]]])

The only problem here is that I would need to convert the i tensor to the new representation, but I can do that in advance, outside the training routine, so it shouldn’t be a problem.
In the end, which of the two methods do you think is gonna run faster? That’s my only concern, since this part of my pipeline was bottlenecking the rest.
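For the record, the conversion of i into the weight representation can itself be vectorized with a comparison against a position range. And since M is block-diagonal over the batch dimension, a batched matmul with just the diagonal (batch, n_spans, seq) blocks gives the same result as the full tensordot — a sketch, reusing the small example above:

```python
import torch

batch, n_spans, seq, emb = 3, 2, 5, 6
x = torch.randn(batch, seq, emb)
i = torch.tensor([[[1, 3], [4, 5]],
                  [[0, 2], [2, 5]],
                  [[1, 4], [0, 1]]])

pos = torch.arange(seq)                   # (seq,)
start = i[..., 0].unsqueeze(-1)           # (batch, n_spans, 1)
end = i[..., 1].unsqueeze(-1)
mask = (pos >= start) & (pos < end)       # True inside each span
W = mask.float() / (end - start).float()  # each row sums to 1 (span weights)

y = torch.bmm(W, x)                       # (batch, n_spans, emb)
assert torch.allclose(y[0, 0], x[0, 1:3].mean(0), atol=1e-6)
```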