In my use cases, a very common pattern is to use gather and then reduce over one dimension:
    import torch

    B = 4
    D = 16
    W = 128
    N = 1000
    sketch = torch.randn(B, D, W)
    codes = torch.randint(W, [B, D, N])
    # gather yields shape (B, D, N); min over D reduces it to (B, N)
    scores = sketch.gather(dim=-1, index=codes).min(dim=-2).values
Unfortunately, this creates an intermediate tensor (after the gather, before the min) whose memory scales linearly with both the D and N dimensions, even though the D dimension is immediately reduced away.
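To make the overhead concrete, here is a small sketch that measures the intermediate against the final result for the sizes above. It assumes the default float32 dtype; the names `gathered`, `gathered_bytes` and `scores_bytes` are only introduced for illustration.

```python
import torch

B, D, W, N = 4, 16, 128, 1000
sketch = torch.randn(B, D, W)          # float32 by default
codes = torch.randint(W, [B, D, N])

# Intermediate produced by gather: one value per (batch, row, code).
gathered = sketch.gather(dim=-1, index=codes)   # shape (B, D, N)
# Final result after reducing the D dimension.
scores = gathered.min(dim=-2).values            # shape (B, N)

gathered_bytes = gathered.numel() * gathered.element_size()
scores_bytes = scores.numel() * scores.element_size()

# The intermediate is D times larger than the result that is actually kept.
print(tuple(gathered.shape), gathered_bytes)
print(tuple(scores.shape), scores_bytes)
```

With these sizes the intermediate holds B * D * N = 64,000 floats while the result holds only B * N = 4,000, so the ratio is exactly D.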
I can loop over the D dimension and reduce into an accumulator to avoid the unnecessary memory blow-up:
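    # start from +inf, the identity element for min (zeros would clamp every positive score to 0)
    scores = torch.full((B, N), float("inf"))
    for i in range(D):
        # gather one row of the sketch at a time: the intermediate is only (B, N)
        scores = torch.minimum(scores, sketch[:, i].gather(dim=-1, index=codes[:, i]))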
but unfortunately it gets extremely slow.
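One possible middle ground, as a rough sketch rather than a real solution, is to process the D dimension in chunks: each step gathers `chunk` rows at once, so the intermediate is (B, chunk, N) and the Python loop runs D / chunk times. The helper name `gather_min_chunked` and its `chunk` parameter are hypothetical, not an existing PyTorch API, and I have not benchmarked it.

```python
import torch

def gather_min_chunked(sketch, codes, chunk=4):
    """Gather along the last dim and take the min over dim -2, `chunk` rows at a time.

    Hypothetical helper: peak intermediate size is (B, chunk, N) instead of (B, D, N).
    """
    B, D, _ = sketch.shape
    N = codes.shape[-1]
    # +inf is the identity element for min.
    scores = torch.full((B, N), float("inf"), dtype=sketch.dtype, device=sketch.device)
    for start in range(0, D, chunk):
        rows = sketch[:, start:start + chunk]        # (B, <=chunk, W)
        idx = codes[:, start:start + chunk]          # (B, <=chunk, N)
        # Vectorized gather + min within the chunk, then fold into the accumulator.
        part = rows.gather(dim=-1, index=idx).min(dim=-2).values
        scores = torch.minimum(scores, part)
    return scores

B, D, W, N = 4, 16, 128, 1000
sketch = torch.randn(B, D, W)
codes = torch.randint(W, [B, D, N])

reference = sketch.gather(dim=-1, index=codes).min(dim=-2).values
chunked = gather_min_chunked(sketch, codes, chunk=4)
print(torch.equal(reference, chunked))
```

This only trades memory against loop overhead through `chunk`; it still materializes a (smaller) intermediate, so it does not replace a fused operation.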
Is it possible to somehow get the best of both worlds? It seems like it should be doable with some kind of native gather-and-reduce operation.
To the best of my knowledge, this would require a custom native operation. Is that correct? If so, could you share some links, snippets, or code references that would be a good starting point for implementing such a native operation?