These are the variables we have:

```
M = 5  # this is input from the user
data = torch.rand(batch_size, 1024)
multipliers = [torch.rand(1024, 512) for _ in range(M)]
multipliers_to_multiply_each_row_of_data = [
    [1, 2],
    [2, 4],
    …
    [1, 4],
    [3, 2],
]
```

We want to do a selective tensor contraction of `data` and `multipliers`, as given by `multipliers_to_multiply_each_row_of_data`. So the last row, `[3,2]`, means we need to contract the last row of `data` with the 3rd and 2nd matrices in `multipliers`. In this case we have two multipliers per data row, but that count is essentially a hyperparameter, say `K`.
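For concreteness, a naive per-row reference implementation of this operation might look like the following. This is only a sketch: `idx` is a hypothetical stand-in for `multipliers_to_multiply_each_row_of_data` as a `(T, K)` index tensor, the indices are assumed to index directly into `multipliers`, and the `K` products are simply concatenated per row (one possible output layout among several):

```python
import torch

torch.manual_seed(0)
batch_size, M, K = 8, 5, 2
data = torch.rand(batch_size, 1024)
multipliers = [torch.rand(1024, 512) for _ in range(M)]
# hypothetical index tensor: K multiplier indices per row of `data`
idx = torch.randint(0, M, (batch_size, K))

# naive reference: loop over rows, concatenating the K contractions
result = torch.empty(batch_size, K * 512)
for t in range(batch_size):
    for k in range(K):
        result[t, k * 512:(k + 1) * 512] = data[t] @ multipliers[idx[t, k]]
```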

The best I could come up with was the following, where I try to apply each `multiplier` only once:

```
# multiplier_indices: `multipliers_to_multiply_each_row_of_data` as a (T, K) LongTensor
one_hot_multiplier_mask = F.one_hot(multiplier_indices, num_classes=self.num_multipliers)  # (T, K, M)
one_hot_multiplier_mask = one_hot_multiplier_mask.permute(2, 1, 0)  # (M, K, T)
# torch.where returns indices per dimension, in (M, K, T) order
where_multiplier, where_k, where_T = torch.where(one_hot_multiplier_mask)
for midx in range(self.num_multipliers):
    selection_mask1 = where_multiplier == midx
    selection_mask2 = where_T[selection_mask1]  # rows of `data` that use multiplier `midx`
    if selection_mask2.nelement() > 0:
        A = data[selection_mask2]  # this is the bottleneck!!
        current_data = A @ multipliers[midx]
        result[selection_mask2, midx * self.multiplier_size:(midx + 1) * self.multiplier_size] = current_data
```

The PyTorch profiler indicated that the bottleneck is `A = data[selection_mask2]`.

Looking at the PyTorch docs, I found that indexing with a tensor actually copies the data, whereas indexing with an integer or a slice returns a view.
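A quick illustration of that difference (the tensor `x` and the values here are just for demonstration):

```python
import torch

x = torch.arange(12).reshape(4, 3)

v = x[1:3]          # basic slicing: returns a view sharing x's storage
v[0, 0] = -1        # so mutating the view is visible in x

c = x[torch.tensor([1, 2])]  # tensor ("advanced") indexing: returns a copy
c[0, 0] = -99                # mutating the copy leaves x untouched
```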

Can this be done any faster?