To detach or not to detach

I know that detach() is used for detaching a tensor from the computational graph. In that context, are the expressions x = x - torch.mean(x, dim=0).detach() and x = x - torch.mean(x, dim=0) equivalent? I just want to subtract the mean out; I don't want to pass gradients through the mean calculation.

Solution: No. In the second case you'll have an additional trainable shared term. E.g. for a two-element vector, xout[0] = x[0] - (x[0] + x[1]) / 2, so the gradient of xout[0] will also affect x[1].
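A minimal sketch of that coupling, using a made-up two-element tensor and backpropagating from xout[0] in each case (the values here are just illustrative):

```python
import torch

x = torch.tensor([1.0, 3.0], requires_grad=True)

# Without detach(): the mean is part of the graph,
# so xout[0] = x[0] - (x[0] + x[1]) / 2.
xout = x - torch.mean(x, dim=0)
xout[0].backward()
print(x.grad)  # tensor([ 0.5000, -0.5000]) -> xout[0] reaches x[1] through the mean

x.grad = None

# With detach(): the mean is treated as a constant.
xout = x - torch.mean(x, dim=0).detach()
xout[0].backward()
print(x.grad)  # tensor([1., 0.]) -> gradient flows only to x[0]
```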

The above discussion was closed, but I want to expand on this question a little bit.

I understand that the two cases will differ, but how exactly does adding .detach() to the second term affect training? I assume it depends on the gradient of xout, but I'm curious about how the parameters are adjusted in each scenario:

  1. When subtracting the mean without .detach()
  2. When subtracting the mean with .detach()

Generally, subtracting the mean of a tensor centers it around zero, but how does using .detach() influence this behavior?

Subtracting the mean in the forward pass also results in the incoming gradient being mean-subtracted in the same way during the backward pass. If detach() is used, the mean is treated as a constant, so no such centering is applied to the gradient during backward; it passes through to x unchanged.
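A small sketch of that backward behaviour, feeding an explicit upstream gradient g into backward() in each case (the tensor values and g are just made-up examples):

```python
import torch

g = torch.tensor([1.0, 2.0, 6.0])   # upstream gradient passed to backward()

x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)

# Without detach(): x.grad comes back mean-centered, i.e. g - g.mean().
y = x - torch.mean(x, dim=0)
y.backward(g)
print(x.grad)  # tensor([-2., -1.,  3.])

x.grad = None

# With detach(): x.grad is just g, passed through unchanged.
y = x - torch.mean(x, dim=0).detach()
y.backward(g)
print(x.grad)  # tensor([1., 2., 6.])
```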