Gather-and-reduce in one operation


In my use cases, a very common pattern is to call gather and then reduce over one dimension:


import torch

B = 4
D = 16
W = 128
N = 1000
sketch = torch.randn(B, D, W)
codes = torch.randint(W, [B, D, N])
scores = sketch.gather(dim=-1, index=codes).min(dim=-2).values

Unfortunately, this creates an intermediate tensor (after the gather, before the min) whose memory scales linearly with both D and N, even though the D dimension is reduced immediately afterwards.
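To make the size of that intermediate concrete, here is a small sketch using the sizes from the snippet above, assuming the default float32 dtype for sketch and the default int64 dtype for codes. For comparison it also prints the size of the index tensor codes, which has the same B x D x N element count.

```python
import torch

B, D, W, N = 4, 16, 128, 1000
sketch = torch.randn(B, D, W)               # float32 by default
codes = torch.randint(W, (B, D, N))         # int64 by default

gathered = sketch.gather(dim=-1, index=codes)  # the intermediate before .min(dim=-2)
intermediate_bytes = gathered.numel() * gathered.element_size()
index_bytes = codes.numel() * codes.element_size()

print(gathered.shape)       # torch.Size([4, 16, 1000])
print(intermediate_bytes)   # 4 * 16 * 1000 * 4 bytes = 256000
print(index_bytes)          # 4 * 16 * 1000 * 8 bytes = 512000
```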

I can loop over the D dimension and reduce into an accumulator to avoid the unnecessary memory blow-up:

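# start at +inf so the first minimum keeps the gathered values (zeros would clamp the result at 0)
scores = torch.full((B, N), float("inf"))
for i in range(D):
    # gather one D-slice ([B, W] indexed by [B, N] -> [B, N]) and fold it into the running minimum
    scores = torch.minimum(scores, sketch[:, i].gather(dim=-1, index=codes[:, i]))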

but unfortunately this is extremely slow, presumably because each of the D iterations launches its own small gather and minimum operations.

Is it possible to get the best of both worlds? It seems like this should be doable with some kind of native gather-and-reduce operation.
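One middle ground that stays in plain PyTorch is to process D in chunks: gather a block of slices at once, reduce it, and fold the result into a running minimum. This is a sketch, not a fused kernel; the function name gather_min_chunked and the chunk parameter are hypothetical. The peak intermediate is B x chunk x N instead of B x D x N, and the number of loop iterations drops from D to about D / chunk.

```python
import torch


def gather_min_chunked(sketch: torch.Tensor, codes: torch.Tensor, chunk: int = 4) -> torch.Tensor:
    """Hypothetical helper: same result as sketch.gather(-1, codes).min(-2).values.

    sketch: [B, D, W] values, codes: [B, D, N] indices into the last dim of sketch.
    Only a [B, chunk, N] intermediate is materialized at a time.
    """
    B, D, _ = sketch.shape
    N = codes.shape[-1]
    # running minimum, initialized to +inf so the first chunk always wins
    scores = torch.full((B, N), float("inf"), dtype=sketch.dtype, device=sketch.device)
    for start in range(0, D, chunk):
        stop = min(start + chunk, D)
        # [B, c, W] gathered with [B, c, N] -> [B, c, N]
        part = sketch[:, start:stop].gather(dim=-1, index=codes[:, start:stop])
        # reduce the chunk over D, then merge into the accumulator
        scores = torch.minimum(scores, part.amin(dim=-2))
    return scores
```

With chunk=1 this degenerates to the per-slice loop, and with chunk=D it is the original single gather, so the chunk size trades memory against the number of kernel launches.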

To the best of my knowledge, this would require a custom native operation. Is that correct? If so, could you share some links, snippets, or code references that would be a good starting point for implementing such an operation?
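As a possible starting point for such a native operation, it may help to pin down its semantics first. The pure-Python reference below (the name gather_min_reference is hypothetical, and nested lists stand in for tensors) shows the loop structure a fused kernel would presumably use: one output element per (b, n) pair, with a scalar accumulator over d, so nothing of size D x N is ever stored.

```python
def gather_min_reference(sketch, codes):
    """Hypothetical reference for a fused gather-and-min.

    sketch: nested lists of shape [B][D][W]
    codes:  nested lists of shape [B][D][N], ints in [0, W)
    Returns [B][N] with out[b][n] = min over d of sketch[b][d][codes[b][d][n]].
    """
    B = len(sketch)
    D = len(sketch[0])
    N = len(codes[0][0])
    out = []
    for b in range(B):              # a native kernel would parallelize over (b, n)
        row = []
        for n in range(N):
            acc = float("inf")      # scalar accumulator replaces the [D, N] intermediate
            for d in range(D):
                acc = min(acc, sketch[b][d][codes[b][d][n]])
            row.append(acc)
        out.append(row)
    return out


sketch = [[[1.0, 5.0, 3.0], [4.0, 0.5, 2.0]]]   # B=1, D=2, W=3
codes = [[[0, 1, 2], [1, 2, 0]]]                 # B=1, D=2, N=3
print(gather_min_reference(sketch, codes))       # [[0.5, 2.0, 3.0]]
```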


Totally agree. torch.gather_reduce would be awesome to have.