When to use detach

If I have two different neural networks (parametrized by model1 and model2) and two corresponding optimizers, would the operation below, which uses the models' parameters without detach(), change their gradients? What I want is to compute the (scaled) squared-error loss between the two models' parameters, but update only the optimizer corresponding to model1.

```python
opt1 = torch.optim.SGD(self.model1.parameters(), lr=1e-3)
opt2 = torch.optim.SGD(self.model2.parameters(), lr=1e-3)

# (lamb / 2) * squared L2 distance between the flattened parameters of both models
loss = (self.lamb / 2.) * ((torch.nn.utils.parameters_to_vector(self.model1.parameters()) - torch.nn.utils.parameters_to_vector(self.model2.parameters()))**2).sum()
```


In general, how can I decide whether an operation needs detach() or not?


parameters_to_vector is differentiable, so yes, gradients will flow back to both models.
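To make this concrete, here is a minimal sketch with two small nn.Linear models standing in for model1 and model2 (the models, the lamb value, and the penalty helper are assumptions for illustration). It shows that without detach() both models receive .grad, that opt1.step() still only updates model1, and that detaching model2's side of the difference is one way to keep gradients out of model2 if that is what you want.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

torch.manual_seed(0)

# Hypothetical stand-ins for model1 / model2 (same architecture, different weights)
model1 = nn.Linear(4, 2)
model2 = nn.Linear(4, 2)
opt1 = torch.optim.SGD(model1.parameters(), lr=1e-3)
lamb = 0.1  # hypothetical penalty weight

def penalty(detach_model2=False):
    vec1 = parameters_to_vector(model1.parameters())
    vec2 = parameters_to_vector(model2.parameters())
    if detach_model2:
        vec2 = vec2.detach()  # treat model2's weights as a constant target
    return (lamb / 2.) * ((vec1 - vec2) ** 2).sum()

# Without detach(): backward() fills .grad for BOTH models
penalty().backward()
model2_got_grad = all(p.grad is not None for p in model2.parameters())
opt1.step()  # ...but opt1 only updates the parameters it was given (model1's)

# Clear gradients, then detach model2's side of the difference
model1.zero_grad(set_to_none=True)
model2.zero_grad(set_to_none=True)
penalty(detach_model2=True).backward()
model1_got_grad = all(p.grad is not None for p in model1.parameters())
model2_got_grad_after_detach = any(p.grad is not None for p in model2.parameters())

print(model2_got_grad, model1_got_grad, model2_got_grad_after_detach)  # True True False
```

Whether model2 receiving gradients matters depends on what happens next: if opt2.step() is later called without zeroing model2's gradients first, model2 will also be pulled toward model1 by this penalty.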

In general, there are very few cases where you need .detach() inside your training function. It is most often used when you want to save the loss for logging, or keep a Tensor for later inspection, and you don't need its gradient information.
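For the logging case, a minimal sketch (the model, data, and loss_history list are made up for illustration): the stored values are detached, so they carry no autograd history. Calling loss.item() to get a plain Python float is another common choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)  # hypothetical stand-in model
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
x, y = torch.randn(8, 3), torch.randn(8, 1)  # made-up data

loss_history = []  # values kept only for logging/plotting
for step in range(3):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    # Store a Tensor with no autograd history (requires_grad=False, grad_fn=None)
    loss_history.append(loss.detach())

print([t.requires_grad for t in loss_history])  # [False, False, False]
```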


Thanks @albanD for the reply.

Follow-up questions for more clarity:

You mentioned that parameters_to_vector is differentiable. How can I check whether a function is differentiable or not?

Also, can you be more specific about the operations for which detach() is required? For instance, if I pass model.parameters() to another function, or collect them with params = list(self.model1.parameters()), do I need to use detach()?

In general, all ops in PyTorch are differentiable.
The main exceptions are .detach() and code run inside a with torch.no_grad() block, as well as functions that work with nn.Parameter objects, which need to remain leaves and so cannot have gradient history.
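One practical way to check, as a heuristic rather than an official "is differentiable" API: give the function an input that requires grad and look at the output. If autograd recorded the op, the result has requires_grad=True and a non-None grad_fn. The sketch below (w is an arbitrary example Parameter) applies this to parameters_to_vector and to the two exceptions mentioned above. Keep in mind that ops with integer-valued outputs, such as torch.argmax, never carry gradient history.

```python
import torch
from torch.nn.utils import parameters_to_vector

w = torch.nn.Parameter(torch.randn(3))  # arbitrary leaf that requires grad

# A differentiable op records a backward node in grad_fn
v = parameters_to_vector([w])
print(v.requires_grad, v.grad_fn is not None)  # True True

# Under torch.no_grad(), nothing is recorded
with torch.no_grad():
    v_ng = parameters_to_vector([w])
print(v_ng.requires_grad, v_ng.grad_fn)  # False None

# detach() returns a Tensor cut off from the graph
v_det = v.detach()
print(v_det.requires_grad, v_det.grad_fn)  # False None
```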

> Also, can you be more specific about the operations in which detach is required

You most likely never need it, actually 🙂
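Regarding params = list(self.model1.parameters()): list() only collects references to the existing nn.Parameter objects, so no detach() is needed and gradients computed through them still land in model1. A minimal sketch, where nn.Linear stands in for model1 and l2_norm is a hypothetical helper:

```python
import torch.nn as nn

model1 = nn.Linear(4, 2)  # hypothetical stand-in for self.model1

# list() just collects references to the same nn.Parameter objects
params = list(model1.parameters())
print(all(p is q for p, q in zip(params, model1.parameters())))  # True
print(all(p.is_leaf and p.requires_grad for p in params))  # True

# Passing them to another function (hypothetical helper) needs no detach()
def l2_norm(ps):
    return sum((p ** 2).sum() for p in ps)

l2_norm(params).backward()
print(all(p.grad is not None for p in model1.parameters()))  # True
```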


The first part makes sense but not the second part. Sorry if that’s annoying.

detach() is used to break the graph, in order to mess with the gradient computation.
In 99% of cases, you never want to do that.

The only unusual cases where it can be useful are the ones I mentioned above: you want to use a Tensor that was produced by a differentiable function in a function that is not expected to be differentiated (logging, or saving it for later inspection). detach() lets you express that. But this is a rare case.
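A common example of such a case is handing a Tensor to code outside autograd, such as NumPy. PyTorch refuses to call .numpy() on a Tensor that requires grad, and detach() is the way to say that this use is not meant to be differentiated. A minimal sketch with an arbitrary example tensor:

```python
import torch

w = torch.randn(3, requires_grad=True)  # arbitrary example input
y = w * 2  # y is part of the autograd graph

# NumPy is outside autograd, so PyTorch raises instead of silently dropping the graph
raised = False
try:
    y.numpy()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# detach() makes it explicit that this use will not be differentiated
arr = y.detach().numpy()
print(arr.shape)  # (3,)
```

The detached Tensor shares storage with the original, so for a CPU Tensor the NumPy array is a view rather than a copy; call .copy() on it if you need an independent array.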
