I know that detach() is used for detaching a variable from the computational graph. In that context, are the following two expressions equivalent?

x = x - torch.mean(x, dim=0).detach()
x = x - torch.mean(x, dim=0)

I just want to subtract the mean out; I don't want to pass gradients through the average calculation.
Solution: No, in the second case you'll have an additional trainable shared term. E.g., for a two-element vector you'll have xout[0] = x[0] - (x[0]+x[1])/2, so the gradient of xout[0] will affect x[1].
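To make that shared-term effect concrete, here is a minimal sketch (my own addition, not from the original answer) using the two-element vector from the example above; it checks how the gradient of xout[0] reaches x[1] with and without .detach():

```python
import torch

# Without .detach(): the mean stays in the graph, so
# xout[0] = x[0] - (x[0] + x[1]) / 2 depends on both inputs.
x = torch.tensor([1.0, 3.0], requires_grad=True)
xout = x - torch.mean(x, dim=0)
xout[0].backward()
print(x.grad)   # tensor([ 0.5000, -0.5000]) -> the gradient of xout[0] touches x[1]

# With .detach(): the mean is treated as a constant, so for autograd
# xout[0] behaves like x[0] minus a fixed offset.
x = torch.tensor([1.0, 3.0], requires_grad=True)
xout = x - torch.mean(x, dim=0).detach()
xout[0].backward()
print(x.grad)   # tensor([1., 0.]) -> no gradient reaches x[1]
```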
The above discussion was closed, but I want to expand on this question a little bit.
I understand that the two cases will differ, but how exactly does adding .detach() to the second term affect training? I assume it depends on the gradient of xout, but I'm curious about how the parameter adjusts in each scenario:
- When subtracting the mean without .detach()
- When subtracting the mean with .detach()
Generally, subtracting the mean of a tensor centers it around zero, but how does using .detach() influence this behavior?
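A small sketch of my own (not from the thread) illustrating the distinction behind this follow-up: both versions produce the same zero-centered output in the forward pass, and .detach() only changes which gradients flow back through the mean.

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 6.0]], requires_grad=True)

# Forward pass: both versions subtract the same column means.
centered = x - torch.mean(x, dim=0)                     # mean stays in the graph
centered_detached = x - torch.mean(x, dim=0).detach()   # mean treated as a constant

print(torch.allclose(centered, centered_detached))      # True
print(centered_detached.detach().mean(dim=0))           # tensor([0., 0.])

# Backward pass: with the mean in the graph, a uniform upstream gradient
# cancels out through the centering; the detached version passes it straight through.
centered.sum().backward()
print(x.grad)            # tensor([[0., 0.], [0., 0.]])

x.grad = None
centered_detached.sum().backward()
print(x.grad)            # tensor([[1., 1.], [1., 1.]])
```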