Hi there,
My question is as follows:
If I have a tensor `a`, which is an intermediate output of a network and has requires_grad=True, and I pass `a.detach()` into a function, like do_sth(a.detach()), which does something else with it. After the function returns, back in the main flow, `a` is used to compute a loss, followed by loss.backward(). The code is roughly as below:
a = net(input)        # a.requires_grad is True
do_sth(a.detach())    # pass a detached copy into the function
loss = nn.L1Loss()(a, label)
loss.backward()
Will the weights of the network that generates `a` still receive gradients when loss.backward() is called?
Hi,
Yes!
`a.detach()` returns a new tensor rather than modifying `a` in-place, so the computation graph leading into `a` is untouched. As long as the computation graph of `loss` doesn't break, you should be fine updating the parameters of `net`.
So, make sure nothing inside your function `do_sth` breaks the graph, for example an in-place operation on the detached tensor: it still shares storage with `a`, so modifying it in-place would also change the values autograd needs for the backward pass.
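Here is a minimal runnable sketch of the situation described above (the `nn.Linear` net, `do_sth` body, and input shapes are placeholders, not the asker's actual code). It shows that calling `do_sth(a.detach())` leaves the graph intact and `net`'s parameters still receive gradients:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the asker's network and data (hypothetical shapes).
net = nn.Linear(4, 4)
inp = torch.randn(2, 4)
label = torch.randn(2, 4)

def do_sth(t):
    # `t` is detached: computations here are outside the autograd graph.
    return (t * 2).sum()

a = net(inp)                 # a.requires_grad is True
do_sth(a.detach())           # uses a detached copy; graph of `a` untouched

loss = nn.L1Loss()(a, label)
loss.backward()

# net's parameters received gradients despite the detach() call.
print(net.weight.grad is not None)  # True
```

Note that `a.detach()` itself has `requires_grad=False`, which is why nothing inside `do_sth` is tracked; but since it is a new tensor, `a` and its history are unaffected.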