Detach() for this specific case

I know that detach() is used to detach a tensor from the computational graph. In that context, are the expressions x = x - torch.mean(x, dim=0).detach() and x = x - torch.mean(x, dim=0) equivalent? I just want to subtract the mean; I don’t want gradients to flow through the mean calculation.

No. In the second case the mean is an additional shared term in the graph, and gradients flow through it.

For example, for a two-element vector you’ll have xout[0] = x[0] - (x[0] + x[1])/2, so the gradient of xout[0] will also affect x[1].
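A minimal sketch of the two-element example, assuming x = [1.0, 3.0] (values chosen only for illustration). It backpropagates from xout[0] alone and reports the gradient that reaches x with and without detach():

```python
import torch

def grad_of_first_output(detach_mean: bool) -> torch.Tensor:
    # Two-element vector; the values are arbitrary
    x = torch.tensor([1.0, 3.0], requires_grad=True)
    m = torch.mean(x, dim=0)
    if detach_mean:
        m = m.detach()  # the mean is treated as a constant in the backward pass
    xout = x - m
    # Backpropagate only from the first output element
    xout[0].backward()
    return x.grad

print(grad_of_first_output(detach_mean=False).tolist())  # [0.5, -0.5]
print(grad_of_first_output(detach_mean=True).tolist())   # [1.0, 0.0]
```

Without detach() the gradient of xout[0] is split between x[0] and x[1]; with detach() only x[0] receives it.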

Hello, I want to expand on this question a little bit.

I understand that the two cases differ, but how exactly does adding .detach() to the second term affect training? I assume it depends on the gradient of xout, but I’m curious how the parameters are adjusted in each scenario:

  1. When subtracting the mean without .detach()
  2. When subtracting the mean with .detach()

Generally, subtracting the mean of a tensor centers it around zero, but how does using .detach() influence this behavior?
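As a point of reference, detach() changes nothing in the forward pass: both expressions produce the same centered tensor, and only the backward pass differs. A quick check, assuming a hypothetical random batch of shape (4, 3):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, requires_grad=True)  # hypothetical batch: 4 samples, 3 features

centered = x - torch.mean(x, dim=0)
centered_detached = x - torch.mean(x, dim=0).detach()

# The forward values are identical in both cases
print(torch.equal(centered, centered_detached))
# Both are centered: the per-feature mean is zero up to float rounding
print(centered.mean(dim=0))
```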

The difference is that, if you don’t detach the mean, the incoming gradients will backpropagate during training both directly to x and through the mean operation back to x. It is up to you whether that’s what you want.
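Concretely, for xout_i = x_i - (1/N) * sum_j x_j over N rows, an upstream gradient g = dL/dxout gives dL/dx_j = g_j - (1/N) * sum_i g_i without detach(), and simply dL/dx_j = g_j with detach(). In other words, the non-detached version removes the per-feature mean of the incoming gradient. The sketch below checks this numerically, assuming a random x and a random upstream gradient g that stand in for whatever the rest of the network would produce:

```python
import torch

def input_grad(x: torch.Tensor, g: torch.Tensor, detach_mean: bool) -> torch.Tensor:
    """Return dL/dx when the upstream gradient dL/dxout equals g."""
    x = x.detach().clone().requires_grad_(True)
    m = torch.mean(x, dim=0)
    if detach_mean:
        m = m.detach()
    xout = x - m
    xout.backward(g)  # vector-Jacobian product with the upstream gradient g
    return x.grad

torch.manual_seed(0)
x = torch.randn(5, 2)
g = torch.randn(5, 2)  # hypothetical upstream gradient

# Without detach: the per-feature mean of g is subtracted from the gradient
print(torch.allclose(input_grad(x, g, detach_mean=False), g - g.mean(dim=0), atol=1e-6))  # True
# With detach: the gradient passes through unchanged
print(torch.allclose(input_grad(x, g, detach_mean=True), g, atol=1e-6))  # True
```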

Thank you for your response.

This is what I think:
When we don’t use detach, the network will likely “learn” to adjust the specified tensor toward a zero mean. On the other hand, when we do use detach, the parameters producing the tensor should be learned independently of its mean.

I’d appreciate any further insights, as I’m still not entirely sure how the behavior will play out in either case. I’m digging into this problem in depth because my network is sensitive to the statistical properties of the specified tensors.
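One way to probe the intuition above is to look at the gradient that reaches a bias term upstream of the centering. The sketch assumes x is produced by a hypothetical nn.Linear layer on a random batch and trained with a made-up MSE loss against a random target; none of these come from the thread:

```python
import torch
import torch.nn as nn

def bias_grad(detach_mean: bool) -> torch.Tensor:
    torch.manual_seed(0)              # same layer, batch and target in both runs
    layer = nn.Linear(3, 2)           # hypothetical layer that produces x
    z = torch.randn(8, 3)             # hypothetical batch of 8 inputs
    target = torch.randn(8, 2)        # hypothetical regression target
    x = layer(z)
    m = torch.mean(x, dim=0)
    if detach_mean:
        m = m.detach()                # mean treated as a constant in backward
    xout = x - m
    loss = ((xout - target) ** 2).mean()
    loss.backward()
    return layer.bias.grad

print(bias_grad(detach_mean=False))  # ~0 for every feature
print(bias_grad(detach_mean=True))   # generally nonzero
```

Adding a constant to every row of x is undone by the centering, so the true gradient with respect to the bias is zero, which is what the non-detached version reports. The detached version reports a nonzero bias gradient, so an optimizer would keep moving the bias even though the centered output never changes.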