Combine matrix multiplication and indexing to reduce memory usage


Consider the following example:

```python
import torch

n = 100
a = torch.randn(n, 1, 100)
b = torch.randn(3, 100, 101)
i = torch.randint(0, 3, size=(n,))  # one index into b per batch element
c = b[i]  # shape (n, 100, 101)
d = torch.bmm(a, c)  # shape (n, 1, 101)
```
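To put a number on the cost: indexing with an integer tensor is advanced indexing, so `c` is a new tensor with its own storage, not a view of `b`. The short sketch below (float32 assumed, same shapes as above) compares its size with the size of `b`:

```python
import torch

n = 100
b = torch.randn(3, 100, 101)        # float32 by default: 4 bytes per element
i = torch.randint(0, 3, size=(n,))
c = b[i]                             # advanced indexing returns a copy

c_bytes = c.nelement() * c.element_size()  # n * 100 * 101 * 4
b_bytes = b.nelement() * b.element_size()  # 3 * 100 * 101 * 4
print(c_bytes, b_bytes)
```

Here `c` needs 4,040,000 bytes against 121,200 bytes for `b`, and that gap grows linearly with `n`.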

Here, indexing `b` with `i` allocates `c` as a full tensor of shape (n, 100, 101), which becomes prohibitively expensive in memory as n grows. However, I imagine a CUDA kernel could be written that fuses the indexing and the batched matrix multiplication, so that `c` is never allocated.
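Short of a custom kernel, two PyTorch-level workarounds avoid materializing `c`; the sketch below is not a fused kernel, and the function names `gathered_bmm_grouped` and `gathered_bmm_all_then_pick` are hypothetical. The first groups batch elements by their index and multiplies each group by the shared matrix `b[k]`, looping once per distinct matrix (3 here). The second multiplies every row of `a` by all matrices in `b` and then picks the one selected by `i`; its intermediate holds n * m * p * r elements (m = number of matrices in `b`), which is small here because p = 1 and m = 3. Both assume every entry of `i` is a valid index into `b`.

```python
import torch

def gathered_bmm_grouped(a, b, i):
    """Hypothetical helper: torch.bmm(a, b[i]) without building b[i].

    a: (n, p, q), b: (m, q, r), i: (n,) integer indices into b's first dim.
    """
    n, p, _ = a.shape
    out = a.new_empty(n, p, b.shape[2])
    for k in range(b.shape[0]):
        mask = i == k
        if mask.any():
            # (n_k, p, q) @ (q, r) broadcasts to (n_k, p, r); a[mask] copies
            # only the rows of a that use matrix k
            out[mask] = torch.matmul(a[mask], b[k])
    return out

def gathered_bmm_all_then_pick(a, b, i):
    """Hypothetical helper: multiply by every matrix in b, then select."""
    all_products = torch.einsum("npq,mqr->nmpr", a, b)  # (n, m, p, r)
    return all_products[torch.arange(a.shape[0]), i]    # (n, p, r)

torch.manual_seed(0)
n = 100
a = torch.randn(n, 1, 100)
b = torch.randn(3, 100, 101)
i = torch.randint(0, 3, size=(n,))

reference = torch.bmm(a, b[i])  # the original, memory-hungry version
d1 = gathered_bmm_grouped(a, b, i)
d2 = gathered_bmm_all_then_pick(a, b, i)
print(torch.allclose(d1, reference, atol=1e-4),
      torch.allclose(d2, reference, atol=1e-4))
```

The grouped version scales well when `b` holds few matrices, since its Python loop runs once per matrix; the second is simpler but its memory grows with the number of matrices in `b`.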

Does such a method already exist? Or will I have to write this CUDA kernel myself?