```python
import torch

x = torch.tensor([0., 0., 0.], requires_grad=True)
y = x * 3
print(f'grad_fn before inplace op: {y.grad_fn}')
y[:2] = 100  # in-place write into a slice of y
print(f'grad_fn after inplace op: {y.grad_fn}')
y.sum().backward()
print(f'grad at x: {x.grad}')
# Output:
# grad_fn before inplace op: <MulBackward0 object at 0x7fad0404e3a0>
# grad_fn after inplace op: <CopySlices object at 0x7fad0404e3a0>
# grad at x: tensor([0., 0., 3.])
```
As shown above, the tensor `y` already has the grad_fn `MulBackward0`, but after an in-place operation on it, its grad_fn is overwritten with `CopySlices`. However, PyTorch still backpropagates correctly to `x`. I have two questions:
- How many tensors are actually created in the computational graph? If there are only two (`x` and `y`), then:
- How does PyTorch still know that there is a `MulBackward0` associated with tensor `y`, even though its grad_fn has been overwritten with `CopySlices`? If you would say that `MulBackward0` is not lost but hidden, how can I show it?
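One way to look for it (a sketch, assuming a recent PyTorch release): the autograd graph is made of grad_fn nodes rather than tensors, and each node exposes the nodes it feeds into during backward through its `next_functions` attribute. Walking that chain from `y.grad_fn` should show `MulBackward0` one level below `CopySlices`, followed by the `AccumulateGrad` node that writes into `x.grad`. `graph_nodes` is a hypothetical helper written for this example, not a PyTorch API.

```python
import torch

x = torch.tensor([0., 0., 0.], requires_grad=True)
y = x * 3
y[:2] = 100  # in-place write: y.grad_fn is now CopySlices


def graph_nodes(fn, depth=0):
    """Hypothetical helper: return (depth, node) pairs for every autograd
    node reachable from fn by following next_functions."""
    if fn is None:  # inputs that need no gradient show up as None
        return []
    nodes = [(depth, fn)]
    for next_fn, _ in fn.next_functions:
        nodes += graph_nodes(next_fn, depth + 1)
    return nodes


nodes = graph_nodes(y.grad_fn)
for depth, node in nodes:
    # Indent each node by its distance from y.grad_fn
    print("  " * depth + type(node).__name__)

# The leaf x is reached through an AccumulateGrad node, which keeps a
# reference to the tensor whose .grad it accumulates into.
leaf_nodes = [n for _, n in nodes if type(n).__name__ == "AccumulateGrad"]
print(leaf_nodes[0].variable is x)

y.sum().backward()
print(x.grad)
```

In this picture, `y` itself only points at its newest node (`CopySlices`); that node keeps the previous `MulBackward0` as one of its inputs, which is why backward still reaches `x` with the gradient `3` for the element that was not overwritten.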