Cannot execute backward() after expand()

I am testing the difference between “expand()” and “repeat()”. There are some problems when executing the backward pass, and I have a little trouble understanding the error message.

import torch
from torch import nn

a = nn.Parameter(torch.FloatTensor(torch.ones(3,1)), requires_grad=True)
b = a.expand(3,4)
c = a.repeat(1,4)
print(a.data_ptr(),b.data_ptr(),c.data_ptr())
with torch.no_grad():
    a[0,0] = 3
print(b)
print(c)
d = torch.sum(3*b)
d.backward()
print(a.grad)

2152330348608 2152330348608 2152301237760
tensor([[3., 3., 3., 3.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], grad_fn=<AsStridedBackward>)
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], grad_fn=<RepeatBackward>)
Traceback (most recent call last):
  File "d:\Desktop\deepul-master\flow.py", line 232, in <module>
    d.backward()
  File "D:\App\Anaconda3\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "D:\App\Anaconda3\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Index out of range

Hi. I will try to clarify as I understand it myself.

  1. The problem does not come from expand alone. If you do not modify a, the code works:
a = nn.Parameter(torch.FloatTensor(torch.ones(3,1)), requires_grad=True)
b = a.expand(3,4)
c = a.repeat(1,4)
print(a.data_ptr(),b.data_ptr(),c.data_ptr())

print(b)
print(c)
d = torch.sum(3*b)
d.backward()
print(a.grad)
  2. The error message gives a small clue about what is going on: it mentions “unreachable”. The problem is in the line a[0,0] = 3. You are basically making part of a no longer require grad, which makes that part unreachable for the backward computation. But if you change the code to a.data[0,0] = 3, it changes only the data, without touching requires_grad or messing with the computational graph, and all the code works as expected.

I may not be 100% correct in the details of my explanation, but I believe the main idea is correct.

As for the difference between expand and repeat, this post may be of some help to you.

Cheers.

Hi,

Some clarifications on Alexey’s message:

  1. Indeed, the problem is not with expand alone, but with expand combined with the in-place operation, in particular because expand returns a view of a. This is why you don’t see any issue if you remove the in-place operation, or if you replace expand with repeat, which is not a view (it allocates new memory).

  2. “it is about ‘unreachable’”: I don’t think it is. The error comes from an internal bug on our end, I think :confused:
    The reason .data “fixes” the issue is that it hides the in-place op from autograd (in a bad way), so to autograd it looks like you only took the view.
    In general, you should never use .data! Using with torch.no_grad() is the right way to do this!
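
The view-versus-copy distinction in point 1 can be checked directly. Below is a minimal sketch that uses plain tensors (no requires_grad, an assumption made here purely to keep autograd out of the picture) and compares how expand and repeat relate to the memory of a:

```python
import torch

a = torch.ones(3, 1)
b = a.expand(3, 4)   # view: no new memory, stride 0 along the expanded dim
c = a.repeat(1, 4)   # copy: freshly allocated memory

shares_b = a.data_ptr() == b.data_ptr()
shares_c = a.data_ptr() == c.data_ptr()
print(shares_b, shares_c)
print(b.stride(), c.stride())

# Plain tensor, so this in-place write involves no autograd at all.
a[0, 0] = 3.0
print(b[0])  # the view sees the new value
print(c[0])  # the copy does not
```

Because b aliases the storage of a, any in-place write to a is also a change to b, which is what autograd has to account for in the original snippet.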

The temporary fix here is to remove the print of b between the no_grad block and the line where you use it for the sum.
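
As a rough sketch of that temporary fix, here is the original snippet with the print of b removed. Dropping c and the data_ptr print is only trimming for brevity and is not part of the fix. Since every element of a is read four times through the view and scaled by 3, each entry of a.grad should come out as 3 * 4 = 12:

```python
import torch
from torch import nn

a = nn.Parameter(torch.ones(3, 1))
b = a.expand(3, 4)       # view of a, created while autograd is recording

with torch.no_grad():
    a[0, 0] = 3          # in-place update of the base, not recorded by autograd

# Note: b is deliberately not printed between the no_grad block and the sum.
d = torch.sum(3 * b)
d.backward()
print(a.grad)
```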

I opened an issue here that explains the exact problem and discusses how to solve it: Printing should not have (bad) autograd side effects · Issue #49756 · pytorch/pytorch · GitHub

Thank you for your answer. I tried your code and it works well. albanD’s addition also makes sense. Thanks again!

Thanks a lot for the correction. I looked into the issue and understand the problem better now. Sorry for the potentially misleading answer from my side :slightly_frowning_face:

Can you please elaborate on this a bit?

Sorry for the potentially misleading answer from my side :slightly_frowning_face:

No worries at all! Thanks for taking the time to look into the issue and provide a workaround!

Can you please elaborate on this a bit?

.data is a remnant of the time when we had both Variables and Tensors. In that world, it let you get the Tensor inside the Variable (that is, the version without any autograd info).
Today, where both are the same, it just creates a new Tensor with the autograd info of the original Tensor stripped out.

The problem is that no autograd info also means no checking for inplace ops or views! So using such Tensors with other Tensors that actually do autograd can lead to silently wrong gradients!!

So we are in the process of removing it everywhere in our codebase to ensure that we don’t have silently wrong gradients. But users should stop using it as well and replace it with .detach() (which breaks the graph but keeps all the essential view/in-place checks) or torch.no_grad() to locally disable autograd.
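
To make the difference concrete, here is a small sketch. The sigmoid example is an illustration chosen for this purpose and is not code from this thread. sigmoid saves its output for the backward pass, so zeroing that output in place through .data silently gives a wrong gradient, while doing the same through .detach() lets autograd notice the modification and raise an error:

```python
import torch

# With .data: the in-place edit is invisible to autograd.
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = a.sigmoid()
out.data.zero_()             # overwrites the values sigmoid's backward relies on
out.sum().backward()         # runs without complaint
grad_with_data = a.grad.clone()
print(grad_with_data)        # computed from the zeroed output, so it is wrong

# With .detach(): the same edit is tracked and detected at backward time.
a2 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out2 = a2.sigmoid()
out2.detach().zero_()
try:
    out2.sum().backward()
    detected = False
except RuntimeError as err:
    detected = True
    print("autograd raised:", type(err).__name__)
```

The true gradient of sigmoid at these points is nonzero, so the all-zero result in the first half is exactly the kind of silent error described above.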
