I am trying to accomplish the operation below without for-loops, to maximize performance on GPU.
def add_accumulate(src, index, dest):
    B, C, H, W = src.shape
    B, H, W = index.shape
    M, C = dest.shape
    # All overlapping sizes above are the same
    for b in range(B):
        for h in range(H):
            for w in range(W):
                dest[index[b, h, w], :] += src[b, :, h, w]
    return dest
I've noticed that torch.Tensor.scatter_add_ does something similar, but not the same. Any idea how I can modify the function to perform this operation efficiently? I'm guessing I would have to reshape the inputs, since I don't think scatter allows broadcasting the index over the channel dimension.
scatter_add_ can be used to accomplish:
dest[ index[b][c][h][w] ][c][h][w] += src[b][c][h][w]
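For reference, here is one way this could be vectorized (a sketch, not from the original post): flatten the (B, H, W) positions into a single axis with the channel dimension last, then use torch.Tensor.index_add_, which accumulates whole rows into dest and so sidesteps broadcasting the index over C. It also handles duplicate indices correctly by summing their contributions:

```python
import torch

def add_accumulate(src, index, dest):
    # src: (B, C, H, W), index: (B, H, W) with values in [0, M), dest: (M, C)
    B, C, H, W = src.shape
    # Move C last and collapse the B*H*W positions into one axis.
    src_flat = src.permute(0, 2, 3, 1).reshape(-1, C)  # (B*H*W, C)
    idx_flat = index.reshape(-1)                       # (B*H*W,)
    # Accumulate row i of src_flat into row idx_flat[i] of dest.
    dest.index_add_(0, idx_flat, src_flat)
    # Equivalent with scatter_add_, expanding the index over C:
    # dest.scatter_add_(0, idx_flat.unsqueeze(1).expand(-1, C), src_flat)
    return dest
```

Both variants run as a single kernel launch instead of B*H*W Python-level updates; index_add_ is the more direct fit here since the same destination row receives the entire channel vector.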