These are the variables we have:

```
M = 5  # this is input from the user
data = torch.rand(batch_size, 1024)
multipliers = [torch.rand(1024, 512) for _ in range(M)]
multipliers_to_multiply_each_row_of_data = [
    [1, 2],
    [2, 4],
    …
    [1, 4],
    [3, 2],
]
```

We want to do a selective tensor contraction of `data` and `multipliers`, as given by `multipliers_to_multiply_each_row_of_data`. So the last row, `[3,2]`, means we need to contract the last row of `data` with the 3rd and 2nd matrices in `multipliers`. In this case we have two multipliers per data row, but that count is essentially a hyperparameter, say `K`.
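For concreteness, a naive per-row reference implementation of this operation might look like the following. This is only a sketch: `idx` is a hypothetical stand-in for `multipliers_to_multiply_each_row_of_data` as a `(T, K)` index tensor, the indices are assumed to index directly into `multipliers`, and the `K` products are simply concatenated per row (one possible output layout among several):

```python
import torch

torch.manual_seed(0)
batch_size, M, K = 8, 5, 2
data = torch.rand(batch_size, 1024)
multipliers = [torch.rand(1024, 512) for _ in range(M)]
# hypothetical index tensor: K multiplier indices per row of `data`
idx = torch.randint(0, M, (batch_size, K))

# naive reference: loop over rows, concatenating the K contractions
result = torch.empty(batch_size, K * 512)
for t in range(batch_size):
    for k in range(K):
        result[t, k * 512:(k + 1) * 512] = data[t] @ multipliers[idx[t, k]]
```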

The best I could come up with was the following, where I try to apply each `multiplier` only once:

```
# multiplier_indices: `multipliers_to_multiply_each_row_of_data` as a (T, K) LongTensor
one_hot_multiplier_mask = F.one_hot(multiplier_indices, num_classes=self.num_multipliers)  # (T, K, M)
one_hot_multiplier_mask = one_hot_multiplier_mask.permute(2, 1, 0)  # (M, K, T)
# torch.where returns indices per dimension, in (M, K, T) order
where_multiplier, where_k, where_T = torch.where(one_hot_multiplier_mask)
for midx in range(self.num_multipliers):
    selection_mask1 = where_multiplier == midx
    selection_mask2 = where_T[selection_mask1]  # rows of `data` that use multiplier `midx`
    if selection_mask2.nelement() > 0:
        A = data[selection_mask2]  # this is the bottleneck!!
        current_data = A @ multipliers[midx]
        result[selection_mask2, midx * self.multiplier_size:(midx + 1) * self.multiplier_size] = current_data
```

The PyTorch profiler indicated that the bottleneck is `A = data[selection_mask2]`.

Looking at the PyTorch docs, I found that indexing with a tensor actually copies the data, whereas indexing with an integer or a slice returns a view.
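A quick illustration of that difference (the tensor `x` and the values here are just for demonstration):

```python
import torch

x = torch.arange(12).reshape(4, 3)

v = x[1:3]          # basic slicing: returns a view sharing x's storage
v[0, 0] = -1        # so mutating the view is visible in x

c = x[torch.tensor([1, 2])]  # tensor ("advanced") indexing: returns a copy
c[0, 0] = -99                # mutating the copy leaves x untouched
```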

Can this be done any faster?