Gradient computation question

import torch
import torch.nn as nn

a=nn.Parameter(torch.tensor([[2.]]))
b=a*2
print(b)
d=nn.Parameter(torch.tensor(3.))
print(d)
e=b[[0],:]*d
b[0]=e  # in-place
print(e)
loss1=e.sum()
e=b[[0],:]*d
b[0]=e
print(e)
loss1+=e.sum()

loss1.backward() works well. Why? The initial b[0] (the 4. in tensor([[4.]], grad_fn=<MulBackward0>)) has been overwritten in place by e, so how can this loss backpropagate to a?

I find that the gradient with respect to a is computed without any error, but I don't know why. I know PyTorch often runs into problems with in-place operations, so why does this case work?
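To check that the gradient is actually correct, and not merely computed without an error, the snippet above can be compared against a hand-derived closed form. Unrolling the two steps gives e1 = 2·a·d and, because the second step reads the already-overwritten b, e2 = e1·d = 2·a·d², so loss1 = 2ad + 2ad². Hence d(loss1)/da = 2d + 2d² and d(loss1)/dd = 2a + 4ad. The sketch below is just the question's code plus that comparison; the names `a0`, `d0`, `expected_grad_a` and `expected_grad_d` are introduced here for illustration only.

```python
import torch
import torch.nn as nn

a = nn.Parameter(torch.tensor([[2.]]))
d = nn.Parameter(torch.tensor(3.))
b = a * 2

# Step 1: e1 = b * d, then written back into b in place
e = b[[0], :] * d
b[0] = e
loss1 = e.sum()

# Step 2: reads the already-modified b, so e2 = e1 * d
e = b[[0], :] * d
b[0] = e
loss1 += e.sum()

loss1.backward()

# Hand-derived closed form: loss1 = 2*a*d + 2*a*d**2, at a = 2, d = 3
a0, d0 = 2.0, 3.0
expected_grad_a = 2 * d0 + 2 * d0 ** 2       # 24.0
expected_grad_d = 2 * a0 + 4 * a0 * d0       # 28.0
print(a.grad.item(), expected_grad_a)
print(d.grad.item(), expected_grad_d)
```

If autograd had silently used the overwritten values, the two columns would disagree; here they match.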

Hi Bingqing!

The b[[0],:] in the line e=b[[0],:]*d uses advanced indexing (indexing
with a list), which creates a new tensor rather than just a view into b
(basic indexing such as b[0] would return a view that shares b's storage).
It is this new tensor that the multiplication saves in the computation
graph and then uses for backpropagation. So, even though b[0]=e does
modify b in place, the tensor actually stored in the computation graph
is not modified in place.
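The copy-versus-view distinction can also be observed directly through two internal tensor attributes: `Tensor._base` (the tensor a view was taken from, or None for a non-view) and `Tensor._version` (the counter autograd compares against when it raises the "modified by an inplace operation" error). Both are private implementation details rather than public API, so treat this as a minimal sketch that may behave differently across PyTorch releases.

```python
import torch

a = torch.nn.Parameter(torch.tensor([[2.]]))
b = a * 2
c_new = b[[0], :]   # advanced indexing: copies into new storage
c_view = b[0]       # basic indexing: a view sharing b's storage

# _base is the tensor a view was taken from (None for non-views)
print(c_view._base is b, c_new._base is None)

versions_before = (b._version, c_view._version, c_new._version)
b[0] = 99.9         # in-place write into b
versions_after = (b._version, c_view._version, c_new._version)

# The view shares b's version counter, so it sees the bump;
# the copy has its own counter, which is unchanged.
print(versions_before, versions_after)
print(c_view.item(), c_new.item())
```

Because the view's counter moves together with b's, any view that autograd saved earlier is detected as stale at backward time, while the copy produced by b[[0], :] is not.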

Consider:

>>> import torch
>>> print (torch.__version__)
1.10.2
>>>
>>> a = torch.nn.Parameter (torch.tensor ([[2.]]))
>>> b = a * 2
>>> c_new = b[[0], :]    # c_new is a new tensor
>>> c_view = b[0]        # c_view is a view into b
>>> print (b)
tensor([[4.]], grad_fn=<MulBackward0>)
>>> print (c_new)
tensor([[4.]], grad_fn=<IndexBackward0>)
>>> print (c_view)
tensor([4.], grad_fn=<SelectBackward0>)
>>> b[0] = 99.9          # inplace modification of b
>>> print (b)
tensor([[99.9000]], grad_fn=<CopySlices>)
>>> print (c_new)        # c_new is not modified -- won't break backpropagation
tensor([[4.]], grad_fn=<IndexBackward0>)
>>> print (c_view)       # c_view reflects the modification -- breaks backpropagation if it was saved for backward
tensor([99.9000], grad_fn=<AsStridedBackward0>)
>>>
>>> a = torch.nn.Parameter (torch.tensor ([[2.]]))
>>> b = a * 2
>>> print (b)
tensor([[4.]], grad_fn=<MulBackward0>)
>>> d = torch.nn.Parameter (torch.tensor (3.))
>>> print (d)
Parameter containing:
tensor(3., requires_grad=True)
>>> e = b[[0], :] * d    # b[[0], :] is actually a new tensor in the computation graph
>>> b[0] = e             # in-place, but doesn't break backpropagation
>>> print (e)
tensor([[12.]], grad_fn=<MulBackward0>)
>>> loss1 = e.sum()
>>> loss1.backward()
>>>
>>> a = torch.nn.Parameter (torch.tensor ([[2.]]))
>>> b = a * 2
>>> print (b)
tensor([[4.]], grad_fn=<MulBackward0>)
>>> d = torch.nn.Parameter (torch.tensor (3.))
>>> print (d)
Parameter containing:
tensor(3., requires_grad=True)
>>> e = b[0] * d         # b[0] is a view into b
>>> b[0] = e             # in-place -- breaks backpropagation (mul saved the view b[0] to compute d's gradient)
>>> print (e)
tensor([12.], grad_fn=<MulBackward0>)
>>> loss1 = e.sum()
>>> loss1.backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
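If basic (view) indexing is what you want to write, one way to avoid this error, assuming the cost of an extra copy is acceptable, is to call `clone()` on the view before using it. The clone is an independent tensor, so the later in-place write into b no longer touches anything the multiplication saved; this is essentially what b[[0], :] did implicitly in the working example. This is a sketch of one workaround, not the only fix.

```python
import torch

a = torch.nn.Parameter(torch.tensor([[2.]]))
b = a * 2
d = torch.nn.Parameter(torch.tensor(3.))

# clone() turns the view b[0] into an independent tensor, so the
# in-place write into b below does not modify what mul saved.
e = b[0].clone() * d
b[0] = e
loss1 = e.sum()
loss1.backward()   # no RuntimeError

print(a.grad)  # d(loss1)/da = 2 * d
print(d.grad)  # d(loss1)/dd = 2 * a
```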

Best.

K. Frank