Hi,
In my use cases, a very common pattern is to gather and then reduce over one dimension.
Example:
```python
import torch

B = 4
D = 16
W = 128
N = 1000
sketch = torch.randn(B, D, W)
codes = torch.randint(W, [B, D, N])
scores = sketch.gather(dim=-1, index=codes).min(dim=-2).values
```
Unfortunately, this creates an intermediate tensor (the output of gather, before min) of shape [B, D, N], whose memory grows linearly with both D and N, even though the D dimension is reduced away immediately.
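To put a number on it, the snippet below reuses the sizes from the example and compares the size of the gathered intermediate with the size of the final result. It assumes float32 inputs (the `torch.randn` default); the ratio between the two is exactly D.

```python
import torch

B, D, W, N = 4, 16, 128, 1000
sketch = torch.randn(B, D, W)
codes = torch.randint(W, (B, D, N))

gathered = sketch.gather(dim=-1, index=codes)  # shape [B, D, N], fully materialized
scores = gathered.min(dim=-2).values           # shape [B, N], what we actually want

# Bytes held by the temporary vs. bytes held by the result
intermediate_bytes = gathered.numel() * gathered.element_size()
output_bytes = scores.numel() * scores.element_size()
print(tuple(gathered.shape), intermediate_bytes, output_bytes)
```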
I can loop over the D dimension and reduce into an accumulator to avoid the unnecessary memory blow-up:
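```python
# Start from +inf (not zeros) so the running minimum is not clamped at 0
scores = torch.full((B, N), float("inf"))
for i in range(D):
    # sketch[:, i] is [B, W], codes[:, i] is [B, N] -> gathered slice is [B, N]
    scores = torch.minimum(scores, sketch[:, i].gather(dim=-1, index=codes[:, i]))
```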
However, this turns out to be extremely slow.
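One middle ground, not part of the original question and only a sketch, is to reduce over D in chunks instead of one slice at a time, so that peak temporary memory is B × chunk × N rather than B × D × N. `gather_min_chunked` and its `chunk` parameter are hypothetical names used for illustration, not a PyTorch API.

```python
import torch

def gather_min_chunked(sketch, codes, chunk=4):
    """Hypothetical helper: same result as sketch.gather(-1, codes).min(-2).values,
    but only materializes a [B, chunk, N] slice of the gathered tensor at a time."""
    B, D, _ = sketch.shape
    N = codes.shape[-1]
    # Running minimum starts at +inf so the first chunk always wins
    scores = torch.full((B, N), float("inf"), dtype=sketch.dtype, device=sketch.device)
    for start in range(0, D, chunk):
        stop = start + chunk  # slicing past D is fine; the last chunk may be smaller
        part = sketch[:, start:stop].gather(dim=-1, index=codes[:, start:stop])
        scores = torch.minimum(scores, part.min(dim=-2).values)
    return scores
```

`chunk=1` reduces to the per-slice loop above and `chunk=D` to the one-shot version; values in between trade fewer Python iterations and kernel launches against a larger temporary.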
Is there a way to get the best of both worlds? It seems like this should be doable with some kind of native gather-and-reduce operation.
To the best of my knowledge, this would require a custom native operation. Is that correct? If so, could you share some links, snippets, or code references that would be a good starting point for implementing such an operation?
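For reference, the semantics such a fused operation would need to implement can be spelled out as a plain, deliberately slow Python loop. `gather_min_reference` is a hypothetical name, not an existing PyTorch function. The point it illustrates is that each (b, n) output depends only on D scalar lookups, so a native kernel could assign one thread per output element and keep the running minimum in a register instead of allocating the [B, D, N] buffer.

```python
import torch

def gather_min_reference(sketch, codes):
    """Hypothetical spec for a fused gather+min: out[b, n] = min_d sketch[b, d, codes[b, d, n]].
    Pure Python loops, only meant to pin down semantics and to test a real kernel against."""
    B, D, _ = sketch.shape
    N = codes.shape[-1]
    out = torch.empty(B, N, dtype=sketch.dtype, device=sketch.device)
    for b in range(B):
        for n in range(N):
            acc = float("inf")  # per-output accumulator; a kernel would keep this in a register
            for d in range(D):
                idx = int(codes[b, d, n])
                acc = min(acc, sketch[b, d, idx].item())
            out[b, n] = acc
    return out
```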