Can somebody explain in simple words, with an illustration, what the detach function does in PyTorch?
There are basically two things to understand before talking about detach:
- What is a computational graph?
- How does PyTorch build a dynamic computational graph?
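I’m assuming that you know the answers to the questions above. `detach()` returns a new tensor that shares its data with the original one but is cut out of the computational graph: it has `requires_grad=False` and no `grad_fn`. This lets you manipulate tensors without expanding the graph. For instance: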
```python
import torch

x = torch.rand(5).requires_grad_(True)
y = x ** 2 + 1
```
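A minimal sketch of what `detach()` hands back for the `x` and `y` above. The name `d` is only for illustration. The detached tensor is a new tensor object that points at the same storage as `x` but carries no autograd history:

```python
import torch

x = torch.rand(5).requires_grad_(True)
y = x ** 2 + 1

d = x.detach()  # new tensor object, same storage as x, no autograd history

print(x.requires_grad, d.requires_grad)  # True False
print(y.grad_fn is not None)             # True: y was built by tracked ops
print(d.grad_fn)                         # None: d is outside the graph
print(d.data_ptr() == x.data_ptr())      # True: no data was copied

# Shared storage means an in-place edit on d is visible through x
d.zero_()
print(x)  # all zeros, and x still has requires_grad=True
```

Because the storage (and autograd's version counter) is shared, editing `d` in place after `x` has been saved for the backward pass, as it is here for `x ** 2`, makes a later `y.sum().backward()` fail with an in-place modification error. Use `x.detach().clone()` when you need an independent copy.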
Say that you are interested in the sum of the elements of `x`. If you run

```python
print(x.sum())
```

PyTorch records the sum in the computational graph, with `x` as a leaf: the result gets a `grad_fn` and can be backpropagated to `x`. But if you run

```python
print(x.detach().sum())
```

nothing is recorded and no additional graph is created.
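To make the difference visible, here is a small sketch that inspects `grad_fn` on both sums and tries to backpropagate through each. The variable names are illustrative; only the tracked sum can send gradients back to `x`:

```python
import torch

x = torch.rand(5).requires_grad_(True)

s_tracked = x.sum()            # recorded: the result is a node in the graph
s_detached = x.detach().sum()  # computed on a detached tensor: not recorded

print(s_tracked.grad_fn)       # <SumBackward0 object at ...>
print(s_detached.grad_fn)      # None
print(torch.allclose(s_tracked.detach(), s_detached))  # True: same value

# Only the tracked result can be backpropagated to x
s_tracked.backward()
print(x.grad)                  # tensor([1., 1., 1., 1., 1.])

try:
    s_detached.backward()      # no grad_fn, so autograd refuses
except RuntimeError as err:
    print("backward on detached result failed:", err)
```

The numeric value is the same in both cases; what differs is only whether autograd keeps the bookkeeping needed for `backward()`.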
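The most common use of `detach` I’ve seen is when you want to convert PyTorch tensors that have `requires_grad == True` to NumPy arrays. NumPy arrays cannot record autograd history, so you can’t run `x.numpy()` in the snippet above (PyTorch raises a `RuntimeError`); you should run `x.detach().numpy()` instead.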
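A short sketch of the NumPy conversion, assuming a CPU tensor as in the snippet above. The `.cpu()` and `.copy()` calls at the end are an optional precaution added here, not part of the original answer:

```python
import torch

x = torch.rand(5).requires_grad_(True)

raised = False
try:
    x.numpy()  # refused: x requires grad
except RuntimeError as err:
    raised = True
    print("x.numpy() failed:", err)

arr = x.detach().numpy()  # works; for a CPU tensor this shares memory with x
print(type(arr).__name__, arr.shape)  # ndarray (5,)

arr[0] = 7.0        # the change shows up in x as well
print(x[0].item())  # 7.0

# If the tensor might live on a GPU, move it to the CPU before converting,
# and add .copy() when you want an array that is independent of x
safe = x.detach().cpu().numpy().copy()
```

Since the detached array and `x` share memory on the CPU, writing into `arr` silently changes `x`; the `.copy()` variant avoids that at the cost of one allocation.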
Thanks @LeviViana