These are the variables we have:
M = 5  # number of multiplier matrices (input from the user)
data = torch.rand(batch_size, 1024)                      # (T, 1024), T = batch_size
multipliers = [torch.rand(1024, 512) for _ in range(M)]  # M matrices of shape (1024, 512)
multipliers_to_multiply_each_row_of_data = [  # T rows, one list of K multiplier indices per row of data
    [1, 2],
    [2, 4],
    …
    [1, 4],
    [3, 2],
]
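To work with this in PyTorch (and with F.one_hot further down), I turn the nested list into a LongTensor; multiplier_indices is just the name I use for it, and this assumes the elided rows are filled in:

multiplier_indices = torch.as_tensor(multipliers_to_multiply_each_row_of_data)  # (T, K) LongTensor
T, K = multiplier_indices.shape  # T = batch_size rows, K indices per row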
We want to do a selective tensor contraction of data with multipliers, as given by multipliers_to_multiply_each_row_of_data. For example, the last row [3, 2] means we need to contract the last row of data with the 3rd and 2nd matrices in multipliers. In this case there are two multipliers per data row, but that number is really a hyperparameter; let's call it K.
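For reference, this slow double loop spells out exactly what I mean. It assumes 0-based indices (as F.one_hot in my attempt below implies) and writes each per-row product into the slot of its multiplier in a (T, M * 512) output, matching the layout of result in that attempt; result_reference is just my name for this reference output:

result_reference = torch.zeros(T, M * 512)
for t in range(T):
    for m in multiplier_indices[t].tolist():
        # contract row t of data, shape (1024,), with multiplier m, shape (1024, 512) -> (512,)
        result_reference[t, m * 512:(m + 1) * 512] = data[t] @ multipliers[m]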
The best I could come up with was this. Here I tried to execute each multiplier just a single time:
# F is torch.nn.functional; multiplier_indices is the (T, K) index tensor from above,
# self.num_multipliers == M, self.multiplier_size == 512, and result is the preallocated output.
one_hot_multiplier_mask = F.one_hot(multiplier_indices, num_classes=self.num_multipliers)  # (T, K, M)
one_hot_multiplier_mask = one_hot_multiplier_mask.permute(2, 1, 0)  # (M, K, T)
# torch.where returns one index tensor per dimension of the (M, K, T) mask
where_multiplier, where_k, where_T = torch.where(one_hot_multiplier_mask)
for midx in range(self.num_multipliers):
    selection_mask1 = where_multiplier == midx      # mask entries belonging to multiplier midx
    selection_mask2 = where_T[selection_mask1]      # data-row indices that use multiplier midx
    if selection_mask2.nelement() > 0:
        A = data[selection_mask2]  # this is the bottleneck!!
        current_data = A @ multipliers[midx]  # (rows, 1024) @ (1024, 512)
        result[selection_mask2, midx * self.multiplier_size:(midx + 1) * self.multiplier_size] = current_data
The PyTorch profiler indicated that the bottleneck is A = data[selection_mask2].
Looking at the PyTorch docs, I found that indexing with a tensor actually copies the result, as opposed to indexing with an integer or a slice, which returns a view.
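A quick self-contained check of that difference (not from my actual code):

x = torch.rand(8, 4)
view = x[2:5]                      # slice indexing returns a view
copy = x[torch.tensor([2, 3, 4])]  # tensor (advanced) indexing materialises a copy
x[2, 0] = 42.0
print(view[0, 0].item())  # 42.0 -> the slice aliases x's storage
print(copy[0, 0].item())  # unchanged -> the advanced index copied the data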
Can this be done any faster?