W.detach() and then requires_grad_ again?

Hi, the official documentation says that w2 = w.detach() will share the same data with w.
So I wonder what w2 would be after:

w2=w.detach()
w2.requires_grad_()

Is w2 exactly the same object as w after calling requires_grad_() again?
If not, what does the above code do?


It’s not the same Python object, as it is a new Tensor that requires grad. It will share its data with w though.
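A quick way to see this is to check object identity and the underlying storage pointer (a minimal sketch):

```python
import torch

w = torch.randn(3, requires_grad=True)
w2 = w.detach()
w2.requires_grad_()

print(w2 is w)                         # False: a different Python object
print(w2.data_ptr() == w.data_ptr())   # True: same underlying storage
print(w2.requires_grad)                # True, after requires_grad_()

# Because storage is shared, in-place changes through one tensor
# are visible through the other:
with torch.no_grad():
    w2.zero_()
print(w)  # all zeros as well
```

So w2 is a new leaf tensor with its own autograd history, but any in-place modification of its values is reflected in w.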

@albanD Thank you. But I can’t understand what it means for two tensors to be different while sharing the same data.

The difference is that if you use the detached tensor to, for example, compute a loss, the gradients of that loss will be computed backward only up to the point of detachment. If you use the original tensor instead, the gradients will be computed further back through the rest of the graph.

I can explain this with an example:

>>> import torch
>>> m1 = torch.nn.Conv2d(3, 5, 3)
>>> m2 = torch.nn.Conv2d(5, 5, 1)
>>> x = torch.randn(2, 3, 4, 4)

### making a forward pass:
>>> h1 = m1(x)
>>> h2 = h1.detach()
>>> z1 = m2(h1)
>>> z2 = m2(h2)

So in this example, the gradients of z1 will affect both m1 and m2, while the gradients of z2 only affect m2. The reason is that z2 is computed from h2, which is detached from the computation graph.
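Continuing the example above, this can be checked by calling backward() on each branch and inspecting the .grad attributes (a sketch using the same modules as the snippet):

```python
import torch

m1 = torch.nn.Conv2d(3, 5, 3)
m2 = torch.nn.Conv2d(5, 5, 1)
x = torch.randn(2, 3, 4, 4)

h1 = m1(x)
h2 = h1.detach()

# Backprop through the detached branch: only m2 receives gradients.
m2(h2).sum().backward()
print(m1.weight.grad)              # None: the graph was cut at the detach
print(m2.weight.grad is not None)  # True

# Reset m2 and backprop through the attached branch: both get gradients.
m2.zero_grad()
m2(h1).sum().backward()
print(m1.weight.grad is not None)  # True: gradients flow back into m1
```

Note that .grad on m1 stays None for the detached branch rather than being zero: autograd never visits m1's parameters at all, because the backward pass stops at h2.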
