Hi!
Consider the following example:
```python
import torch

n = 100
a = torch.randn(n, 1, 100)
b = torch.randn(3, 100, 101)
i = torch.randint(0, 3, size=(n,))  # one index into b per batch element
c = b[i]  # shape (n, 100, 101)
d = torch.bmm(a, c)  # shape (n, 1, 101)
```
Here, the intermediate tensor `c` is materialized in memory. Its size (n × 100 × 101 elements) becomes prohibitively expensive as n grows. However, I imagine a CUDA kernel could fuse the indexing operation with the batched matrix multiplication, so that `c` is never allocated.
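This is not a fused CUDA kernel. It is a minimal pure-PyTorch workaround that avoids allocating the (n, 100, 101) tensor, assuming the number of distinct matrices in `b` (3 here) is small. It loops over those few matrices instead of gathering one per row. The function name `gathered_bmm` is hypothetical and is not part of PyTorch.

```python
import torch


def gathered_bmm(a: torch.Tensor, b: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: computes torch.bmm(a, b[i]) without materializing b[i].

    a: (n, s, k), b: (m, k, p), i: (n,) integer indices into b's first dimension.
    Returns a tensor of shape (n, s, p).
    Assumes every entry of i lies in [0, m); rows with other indices stay uninitialized.
    """
    n, s, _ = a.shape
    m, _, p = b.shape
    out = a.new_empty((n, s, p))
    # Loop over the m distinct matrices (small) rather than the n batch rows (large).
    for idx in range(m):
        rows = (i == idx).nonzero(as_tuple=True)[0]
        if rows.numel() == 0:
            continue
        # a[rows] is (r, s, k); matmul broadcasts the single (k, p) matrix over the batch,
        # so no (r, k, p) copy of b[idx] is ever created.
        out[rows] = torch.matmul(a[rows], b[idx])
    return out
```

Usage, with the tensors from the example above: `d = gathered_bmm(a, b, i)` should match `torch.bmm(a, b[i])` up to floating-point rounding. Extra memory is roughly the gathered rows of `a` plus the output, instead of n × 100 × 101 elements. A related option is `torch.einsum("nk,mkp->mnp", a.squeeze(1), b)` followed by a gather on `i`. It costs m × n × 101 elements, which is also far smaller than `c` when m is small.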
Does such a method already exist? Or will I have to write this CUDA kernel myself?
Thanks!