gengxingze
(Gengxingze)
October 13, 2023, 2:32pm
1
How can I optimise this code?

```
import torch

a = torch.rand(12800, 100, 3)
b = torch.randint(1, 30000, (12800, 100))
c = torch.zeros(12800, 3)
for i in range(12800):
    # accumulate the rows of a whose label in b equals i
    c[i] = c[i] + a[b == i].sum(dim=0)
```

I tried to use an operator to speed it up, but needing the gradient of this variable prevents it from being exported with JIT.

KFrank
(K. Frank)
October 14, 2023, 11:24pm
2
Hi Xingze!

gengxingze:

```
a = torch.rand(12800, 100, 3)
b = torch.randint(1, 30000, (12800, 100))
c = torch.zeros(12800, 3)
for i in range(12800):
    c[i] = c[i] + a[b == i].sum(dim=0)
```

You can eliminate the for loop by materializing a large numerical (0.0, 1.0)
mask tensor and then using `einsum()` to compute `c`.

Here is an example script (reduced in size so that it fits in memory):

```
import torch
print (torch.__version__)
_ = torch.manual_seed (2023)
a = torch.rand(6400, 100, 3) # size reduced from 12800 to 6400 to fit in memory
b = torch.randint(1, 15000, (6400, 100))
c = torch.zeros(6400, 3)
for i in range(6400):
    c[i] = c[i] + a[b == i].sum(dim=0)
print ('c.shape:', c.shape)
mask = (b.unsqueeze (0) == torch.arange (6400).unsqueeze (-1).unsqueeze (-1)).float() # materialize large mask tensor
print ('mask.shape:', mask.shape)
cB = torch.einsum ('mij, ijn -> mn', mask, a)
print ('cB.shape:', cB.shape)
print ('torch.allclose (c, cB):', torch.allclose (c, cB))
```

And here is its output:

```
2.1.0
c.shape: torch.Size([6400, 3])
mask.shape: torch.Size([6400, 6400, 100])
cB.shape: torch.Size([6400, 3])
torch.allclose (c, cB): True
```
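
The comment in the script hints at the main cost of this approach: at the original problem size the materialized mask would be a 12800 x 12800 x 100 float tensor (roughly 65 GB at float32), which is why the sizes above are halved. Since the original question also needed gradients and JIT export, note that this formulation contains no data-dependent Python loop and is differentiable with respect to `a`. Here is a minimal sketch of that check (the function name `bucket_sum` and the small test sizes are illustrative, not taken from the code above):

```
import torch

def bucket_sum(a: torch.Tensor, b: torch.Tensor, n: int) -> torch.Tensor:
    # row m of the result is the sum of a[i, j] over all (i, j) with
    # b[i, j] == m, i.e. the same quantity the original for loop accumulates
    mask = (b.unsqueeze (0) == torch.arange (n).unsqueeze (-1).unsqueeze (-1)).float()
    return torch.einsum ('mij, ijn -> mn', mask, a)

a = torch.rand(64, 100, 3, requires_grad = True)   # small sizes for a quick check
b = torch.randint(1, 150, (64, 100))

c = bucket_sum(a, b, 64)
c.sum().backward()                        # gradients flow back to a through einsum
scripted = torch.jit.script(bucket_sum)   # no Python loop over i, so scripting works
print (torch.allclose (c, scripted (a, b, 64)))
```

If the mask's memory footprint is a concern, the same per-index sums can also be computed without materializing it, for example with a scatter-add via `torch.Tensor.index_add_` into a scratch buffer sized to the largest label, but that is a different technique from the one shown above.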

Best.

K. Frank