GPU-Efficient Variation of Scatter

I am trying to accomplish the operation below without for-loops, to maximize performance on the GPU.

def add_accumulate(src, index, dest):
    B, C, H, W = src.shape
    B, H, W = index.shape
    M, C = dest.shape
    # All overlapping sizes above are the same

    for b in range(B):
        for h in range(H):
            for w in range(W):
                dest[index[b, h, w], :] += src[b, :, h, w]
    return dest

I’ve noticed that torch.Tensor.scatter_add_ does something similar, but not quite the same. Any idea how I can perform this operation efficiently? I’m guessing I would have to modify the inputs, since I don’t think scatter allows broadcasting.

Scatter can be used to accomplish:
dest[ index[b][c][h][w] ][c][h][w] += src[b][c][h][w]
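For reference, a minimal demonstration of that scatter_add_ semantics along dim 0 (small shapes chosen purely for illustration):

```python
import torch

# scatter_add_ along dim 0: dest[index[i][c]][c] += src[i][c]
src = torch.ones(4, 3)              # four rows of C=3 values
index = torch.tensor([0, 2, 0, 1])  # target row for each source row
dest = torch.zeros(5, 3)            # M=5 destination rows

# index must have the same shape as src, so replicate it across channels
dest.scatter_add_(0, index.unsqueeze(1).expand(-1, 3), src)
# row 0 gets two contributions, rows 1 and 2 one each, rows 3-4 none
```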

If index is UNIQUE, try:

def add_accumulate(src, index, dest):
    C = dest.shape[-1]
    dest[index.flatten(), :] += src.permute(0, 2, 3, 1).reshape(-1, C)
    return dest
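One caveat worth flagging: with duplicate entries in index, fancy-index += only lands one of the writes, which is why the uniqueness condition matters. A minimal demonstration:

```python
import torch

dest = torch.zeros(3)
idx = torch.tensor([0, 0])  # duplicate target index
dest[idx] += torch.tensor([1.0, 1.0])
# only one of the two writes sticks: dest[0] is 1.0, not 2.0
```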

Edit 1: index is not unique

Thanks, but index may have duplicate entries
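For completeness, a duplicate-safe sketch using index_add_, which does accumulate contributions for repeated indices (same shape assumptions as the loop version above; `src2d` is a name introduced here for illustration):

```python
import torch

def add_accumulate(src, index, dest):
    # src: (B, C, H, W), index: (B, H, W), dest: (M, C)
    C = dest.shape[-1]
    # move channels last and flatten so row i matches index.flatten()[i]
    src2d = src.permute(0, 2, 3, 1).reshape(-1, C)
    # index_add_ sums all rows that map to the same destination index
    dest.index_add_(0, index.flatten(), src2d)
    return dest
```

An equivalent scatter_add_ form, if preferred, would be dest.scatter_add_(0, index.flatten().unsqueeze(1).expand(-1, C), src2d).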