Concerning .detach() as an input argument to a function

Hi there,
My question is as follows:
Suppose I have a tensor 'a' that is an intermediate output of a network and has requires_grad=True. I pass 'a.detach()' as an input argument to a function, e.g. do_sth(a.detach()), which does something else with it. Afterwards, back in the main flow, 'a' is used to compute a loss on which I call loss.backward(). The code is roughly as follows:
a = net(input)            # intermediate output, requires_grad=True
do_sth(a.detach())        # the detached tensor goes into the side function
loss = L1loss(a, label)   # the loss is computed from 'a' itself, not the detached copy
loss.backward()
Will the weights of the network that produced 'a' still receive gradients when I call loss.backward()?

Hi,
Yes!
a.detach() returns a new tensor and does not modify 'a' in place, so the computation graph from the network's parameters to 'a' stays intact. As long as the computation graph of loss is not broken, loss.backward() will populate gradients for net's parameters and you can update them as usual.
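
Here is a minimal, self-contained sketch of the situation (the tiny linear net, do_sth, and the random data are made up just for illustration) that checks the gradients still reach the network's parameters:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)          # stand-in for the real network
criterion = nn.L1Loss()

def do_sth(t):
    # side computation on the detached tensor; it builds no graph back to net
    return t.sum().item()

x = torch.randn(8, 4)
label = torch.randn(8, 2)

a = net(x)                     # a.requires_grad is True
do_sth(a.detach())             # detached copy; the graph of 'a' is untouched
loss = criterion(a, label)
loss.backward()

print(net.weight.grad is not None)   # True: gradients reached net's parameters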

So, just make sure nothing inside your function do_sth breaks the graph; the most common pitfall is modifying the detached tensor in place, since it still shares memory with 'a' (see the sketch below).
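
Reusing the toy setup above, here is a hedged sketch of the kind of side effect to avoid: a.detach() still shares storage with 'a', so an in-place write inside the function changes the values autograd saved for the backward pass, and loss.backward() will then typically raise a RuntimeError about a variable modified by an in-place operation.

def bad_do_sth(t):
    # an in-place write on the detached tensor also overwrites 'a',
    # because detach() shares the underlying storage
    t.zero_()

a = net(x)
bad_do_sth(a.detach())
loss = criterion(a, label)
loss.backward()   # typically raises: one of the variables needed for gradient
                  # computation has been modified by an inplace operation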
