PyTorch (bug?): How to pass gradient through index_add?

I’m trying to backpropagate through a Tensor.index_add operation, but it breaks when index_add broadcasts vectors. Is this intended, or is this a bug in PyTorch?

# The following code does NOT throw an error:

import torch

m = torch.zeros(4)
m.requires_grad = True

i = [0, 1, 2, 3]
v = [6, 7, 8, 9]

i = torch.tensor(i).long()
v = torch.tensor(v).float()
v.requires_grad = True

print(m)
m = m.index_add(0, i, v)
print(m)
m.sum().backward()
print(v.grad)
print("DONE")

# The following code throws an error:

import torch

m = torch.zeros(4, 4)
m.requires_grad = True

i = [0, 1, 2, 3]
v = [6, 7, 8, 9]

i = torch.tensor(i).long()
v = torch.tensor(v).float()
v.requires_grad = True

print(m)
m = m.index_add(0, i, v)
print(m)
m.sum().backward()
print(v.grad)
print("DONE")

Running both snippets gives the following output and error:

tensor([0., 0., 0., 0.], requires_grad=True)
tensor([6., 7., 8., 9.], grad_fn=<IndexAddBackward>)
tensor([1., 1., 1., 1.])
DONE
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], requires_grad=True)
tensor([[6., 6., 6., 6.],
        [7., 7., 7., 7.],
        [8., 8., 8., 8.],
        [9., 9., 9., 9.]], grad_fn=<IndexAddBackward>)

Strangely, everything runs fine until the second m.sum().backward() call - up to that point, the 4x4 case looks no more broken than the 1-D case above it.

ERROR: RuntimeError: expand(torch.FloatTensor{[4, 4]}, size=[4]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (2)

Note: I also asked this question on Stack Overflow, but it’s currently unanswered.

A bit more context about my use case: I need to map a bunch of pixels from one image onto another, where some pixels overlap (in which case I want their values summed). It’s like refracting light - parts of the image that get squished together by a deformation should come out brighter. index_add is a critical component of this, and I need it to broadcast across the RGB channel dimension.
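To make the use case concrete, here is a minimal sketch (the shapes, indices, and pixel values are made up for illustration): source pixels are scattered onto a flattened target image with index_add, so pixels that land on the same target location are summed, and the RGB channel dimension rides along.

```python
import torch

H, W = 2, 2  # tiny target image
# 3 source pixels with RGB values; gradients should flow back to these
src = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.],
                    [7., 8., 9.]], requires_grad=True)
# Flat target-pixel index for each source pixel; the first two overlap
idx = torch.tensor([0, 0, 3])

target = torch.zeros(H * W, 3)
target = target.index_add(0, idx, src)  # overlapping pixels are summed
out = target.view(H, W, 3)

out.sum().backward()
print(out)
print(src.grad)  # each source pixel contributes once -> gradient of ones
```

Here src already matches the target’s trailing (channel) dimension, so no broadcasting is needed; the failing case above is the variant where the source has fewer dimensions than the target.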

This sounds like a valid issue, and the workaround via:

m=m.index_add(0,i,v.expand_as(m).contiguous())

seems to work. Would you mind creating an issue on GitHub and linking to this thread, please?
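For reference, a quick sanity check of the workaround on the 4x4 case from above (note that v.expand_as(m) replicates v along rows, so verify this matches the broadcasting semantics you actually want):

```python
import torch

m = torch.zeros(4, 4, requires_grad=True)
i = torch.tensor([0, 1, 2, 3])
v = torch.tensor([6., 7., 8., 9.], requires_grad=True)

# Expand v to m's shape explicitly instead of relying on index_add's
# implicit broadcasting, so the backward pass sees matching shapes.
m2 = m.index_add(0, i, v.expand_as(m).contiguous())
m2.sum().backward()
print(v.grad)  # tensor([4., 4., 4., 4.]) - each element of v appears in 4 rows
```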

Sure thing; I just posted the issue: PyTorch bug: Cannot pass gradient through index_add · Issue #71155 · pytorch/pytorch · GitHub
Thank you for your help!