Consider the following example:
```python
import torch

n = 100
a = torch.randn(n, 1, 100)
b = torch.randn(3, 100, 101)
i = torch.randint(0, 3, size=(n,))  # size must be a tuple
c = b[i]             # shape (n, 100, 101)
d = torch.bmm(a, c)  # shape (n, 1, 101)
```
Here the `c` tensor, of shape (n, 100, 101), is materialized in memory, which becomes prohibitively expensive when n is large. However, I imagine a CUDA kernel could be written that fuses the indexing operation and the batched matrix multiplication, so that `c` is never allocated.
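For comparison, here is a minimal sketch of one workaround in plain PyTorch. It is not a fused kernel. It assumes the number of distinct matrices in `b` (3 here) is small. Instead of gathering `b[i]`, it groups the rows of `a` by their index and runs one ordinary matmul per group, so only the (n, 101) result and the gathered rows of `a` are allocated, never an (n, 100, 101) tensor. The function name `gathered_bmm` is hypothetical.

```python
import torch


def gathered_bmm(a: torch.Tensor, b: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same result as torch.bmm(a, b[i]) without building b[i].

    a: (n, 1, k) batch of row vectors
    b: (g, k, m) small bank of g distinct matrices
    i: (n,) integer tensor with values in [0, g), selecting a matrix for each row
    """
    n = a.shape[0]
    g, m = b.shape[0], b.shape[2]
    a2 = a.squeeze(1)                  # (n, k) view of a, no copy
    out = a2.new_empty((n, m))         # only the (n, m) result is allocated
    for idx in range(g):
        mask = i == idx                # rows of a that use matrix b[idx]
        if mask.any():
            # One ordinary matmul per group: (rows, k) @ (k, m) -> (rows, m)
            out[mask] = a2[mask] @ b[idx]
    return out.unsqueeze(1)            # (n, 1, m), matching torch.bmm's output


torch.manual_seed(0)
n = 100
a = torch.randn(n, 1, 100)
b = torch.randn(3, 100, 101)
i = torch.randint(0, 3, size=(n,))
d = gathered_bmm(a, b, i)              # shape (n, 1, 101)
```

The Python loop launches g separate matmuls, so this only pays off when g is small. Another option in that regime is to multiply `a.squeeze(1)` against all g matrices at once (a (g, n, m) intermediate) and then select each row's result with `i`, trading the loop for g times the output memory.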
Does such a method already exist? Or will I have to write this CUDA kernel myself?