How to optimise a loop with a large number of iterations?

How can I optimise this code?

a = torch.rand(12800, 100, 3)
b = torch.randint(1, 30000, (12800, 100))
c = torch.zeros(12800, 3)
for i in range(12800):
    c[i] = c[i] + a[b == i].sum(dim=0)

I tried using an operator to speed it up, but because I need the gradient of this variable, it cannot be exported using jit.

Hi Xingze!

You can eliminate the for loop by materializing a large numerical (0.0, 1.0)
mask tensor and then using einsum() to compute c.

Here is an example script (reduced in size to fit into my memory):

import torch
print (torch.__version__)

_ = torch.manual_seed (2023)

a = torch.rand(6400, 100, 3)   # reduce size from 12800 to 6400 to fit in memory
b = torch.randint(1, 15000, (6400, 100))
c = torch.zeros(6400, 3)
for i in range(6400):
    c[i] = c[i] + a[b == i].sum(dim=0)

print ('c.shape:', c.shape)

mask = (b.unsqueeze (0) == torch.arange (6400).unsqueeze (-1).unsqueeze (-1)).float()   # materialize large mask tensor

print ('mask.shape:', mask.shape)

cB = torch.einsum ('mij, ijn -> mn', mask, a)

print ('cB.shape:', cB.shape)
print ('torch.allclose (c, cB):', torch.allclose (c, cB))

And here is its output:

c.shape: torch.Size([6400, 3])
mask.shape: torch.Size([6400, 6400, 100])
cB.shape: torch.Size([6400, 3])
torch.allclose (c, cB): True
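
If the mask is too large to materialize (at the original size of 12800 it would hold 12800 * 12800 * 100 elements, roughly 65 GB in float32), a scatter-style sum may be worth trying instead. The following is only a sketch of that alternative, not the einsum() approach above: it swaps in index_add(), and the names n, valid, idx, src and c_scatter are made up for illustration. It assumes the values in b are non-negative (true here, since randint starts at 1) and drops entries of b that fall outside the rows of c, as the loop does. The sizes are reduced further so the reference loop runs quickly. I have not checked whether this exports with jit; it only uses ordinary differentiable tensor operations.

```python
import torch

_ = torch.manual_seed(2023)

n = 640   # illustrative size, much smaller than the original 12800
a = torch.rand(n, 100, 3, requires_grad=True)
b = torch.randint(1, 1500, (n, 100))

# reference result from the original loop (autograd disabled, comparison only)
with torch.no_grad():
    c_loop = torch.zeros(n, 3)
    for i in range(n):
        c_loop[i] = c_loop[i] + a[b == i].sum(dim=0)

# keep only the entries whose value in b is a valid row of c,
# which is what "b == i" selects over i in range(n)
valid = b < n          # assumes b has no negative values
idx = b[valid]         # shape (k,): target row in c for each kept entry
src = a[valid]         # shape (k, 3): the vectors to accumulate

# out-of-place scatter-add: row idx[j] of the result accumulates src[j]
c_scatter = torch.zeros(n, 3).index_add(0, idx, src)

print('torch.allclose (c_loop, c_scatter):', torch.allclose(c_loop, c_scatter))

# gradients flow back to a through the indexing and index_add()
c_scatter.sum().backward()
print('a.grad.shape:', a.grad.shape)
```

The intermediate tensors here scale with the number of elements in b rather than with the number of rows of c times the number of elements in b, which is the reason to consider this over the mask when memory is the constraint.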


K. Frank

You are right, thank you.