Hey, thanks a lot!
I’ll check your proposed solution. In the meantime I actually came up with an alternative solution using tensor algebra.
The idea is to convert the index tensor into another tensor that, multiplied by the encoding tensor, gives the desired result. For example, let’s consider an easier case to visualize, with batch size = 3, encoding dimension = 6, and sequence length = 5. So x is now (3, 5, 6), and let’s say our i has shape (3, 2, 2), with i[0] given by:
i[0] = [ [1,3], [4,5] ]
You want to convert it to a 2x5 matrix m that is nonzero only in the positions defined by i[0]:
m = [ [ 0, 0.5, 0.5, 0, 0 ],
      [ 0, 0,   0,   0, 1 ] ]
(note that each row needs to sum to 1 since it contains the weights for the computation of the average)
such that when you perform the matrix multiplication m.mm(x[0]), you get the desired 2x6 matrix:
[ x[0][1:3].mean(0),
x[0][4:5].mean(0) ]
Now you just have to generalize this to the whole batch using the generalized tensor product. You want to find the tensor M that, multiplied by your x of shape (3, 5, 6), gives the desired (3, 2, 6) tensor — call it y — containing the means of the rows specified in i:
y = torch.tensordot(M, x)
Since x can be viewed as a 15x6 matrix, it’s easy to see that M is going to be a 6x15 matrix, so that y = M * x is a 6x6 matrix, or equivalently the (3, 2, 6) tensor we were looking for. In particular, M corresponds to a (3, 3, 2, 5) tensor, i.e. a 3x3 block matrix of 2x5 blocks, where each diagonal block is built from the corresponding i[b] as m was built from i[0] above, and each off-diagonal block is just zero. With M defined this way you obtain the desired y simply by performing the product:
y = torch.tensordot( M, x, dims=[ [1,3], [0,1] ] )
Below is the code corresponding to the example above:
In [122]: x = torch.randn(3, 5, 6)
In [123]: x
Out[123]:
tensor([[[ 0.1536, 0.9683, -0.1524, 0.0499, -1.5984, -0.6716],
[-0.6679, 0.4353, 0.1829, 1.1936, 0.7692, 0.3316],
[-0.4091, 0.2014, -1.1973, 0.4434, 0.5641, -0.5315],
[-0.4476, -0.2807, -0.9703, -0.3468, 3.5646, 0.0424],
[-0.9625, -0.4611, -0.5875, -1.4666, 0.4561, -0.6519]],
[[ 0.2110, -1.8854, -0.6089, 0.4814, 0.0223, 0.4263],
[-0.1546, -2.3762, 1.4597, 0.1886, -0.7254, 0.8121],
[ 1.2119, -0.8752, 0.7288, -1.2654, -0.7958, -0.0572],
[ 1.9370, -1.4721, 1.0472, -1.4682, 1.0570, 1.0122],
[-0.9708, -0.7558, -2.5213, -0.2033, -0.3877, -0.8172]],
[[-0.8187, -0.4423, 0.7459, 1.0440, 0.1667, -0.6885],
[ 0.1885, -0.5397, -0.2614, 0.0529, -0.0420, -0.1613],
[-1.3891, -0.6847, 0.1743, 0.5257, 0.5144, -0.6880],
[-0.2109, 1.1182, 0.3880, -0.8249, 0.1124, -0.0780],
[ 1.6060, -0.5759, -0.3814, -0.6349, -0.8468, 2.1026]]])
In [124]: m = torch.zeros(2,5)
In [125]: m[0,1:3] = 0.5
In [126]: m[1,4] = 1
In [127]: m
Out[127]:
tensor([[0.0000, 0.5000, 0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 1.0000]])
In [130]: M = torch.zeros(3, 3, 2, 5)
In [131]: M[0,0], M[1,1], M[2,2] = m, m, m
In [132]: M
Out[132]:
tensor([[[[0.0000, 0.5000, 0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 1.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.5000, 0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 1.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.5000, 0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]]])
In [133]: y = torch.tensordot(M,x,dims=[[1,3],[0,1]])
In [134]: y
Out[134]:
tensor([[[-0.5385, 0.3183, -0.5072, 0.8185, 0.6666, -0.1000],
[-0.9625, -0.4611, -0.5875, -1.4666, 0.4561, -0.6519]],
[[ 0.5286, -1.6257, 1.0942, -0.5384, -0.7606, 0.3775],
[-0.9708, -0.7558, -2.5213, -0.2033, -0.3877, -0.8172]],
[[-0.6003, -0.6122, -0.0436, 0.2893, 0.2362, -0.4246],
[ 1.6060, -0.5759, -0.3814, -0.6349, -0.8468, 2.1026]]])
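As a side note, since M is block-diagonal, the same y should also be obtainable with a plain broadcasted matmul against the per-sample 2x5 weight matrix, without materializing the zero off-diagonal blocks. A small sketch (here all samples share the same m, as in the session above):

```python
import torch

x = torch.randn(3, 5, 6)

# Per-sample 2x5 weight matrix, as in the example above.
m = torch.zeros(2, 5)
m[0, 1:3] = 0.5
m[1, 4] = 1.0

# Block-diagonal M and the tensordot from the post:
M = torch.zeros(3, 3, 2, 5)
M[0, 0] = M[1, 1] = M[2, 2] = m
y_td = torch.tensordot(M, x, dims=[[1, 3], [0, 1]])

# matmul broadcasts the 2-D m over the batch dimension of x,
# giving the same (3, 2, 6) result:
y_mm = torch.matmul(m, x)

assert torch.allclose(y_td, y_mm)
```

In the general case where each sample has its own weight matrix, you would stack them into a (3, 2, 5) tensor and use `torch.bmm` (or `torch.matmul`) the same way, which keeps the memory footprint at O(batch) blocks instead of O(batch²).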
The only problem here is that I would need to convert the i tensor to this new representation, but I can do that in advance, outside the training routine, so it shouldn’t be a problem.
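For that conversion step, assuming each i[b, k] is a half-open row range [start, end) as in the example, one possible vectorized sketch (the helper name and signature are mine, not from the original post):

```python
import torch

def ranges_to_weights(i, seq_len):
    """Hypothetical helper: turn an index tensor i of shape (B, K, 2),
    where i[b, k] = [start, end) is a row range, into a (B, K, seq_len)
    weight tensor whose rows average the selected positions."""
    positions = torch.arange(seq_len)          # (seq_len,)
    start = i[..., 0].unsqueeze(-1)            # (B, K, 1)
    end = i[..., 1].unsqueeze(-1)              # (B, K, 1)
    mask = (positions >= start) & (positions < end)   # (B, K, seq_len)
    counts = mask.sum(-1, keepdim=True).clamp(min=1)  # avoid division by zero
    return mask.float() / counts

i = torch.tensor([[[1, 3], [4, 5]]])   # one sample, same ranges as i[0] above
m = ranges_to_weights(i, seq_len=5)
# m[0] == [[0, 0.5, 0.5, 0, 0],
#          [0, 0,   0,   0, 1]]
```

Each row of the result sums to 1 by construction, matching the averaging-weights requirement above.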
In the end, which of the two methods do you think is going to run faster? That’s my only concern, since this part of my pipeline was bottlenecking everything else.