Can somebody explain in simple words, with an illustration, what the detach function does in PyTorch?
There are basically two things to understand before talking about detach:
- What is a Computational Graph?
- How does PyTorch build a Dynamic Computational Graph?
I’m assuming that you know the answers to the questions above. detach is a method that returns a new tensor sharing the same data as the original, but cut off from the computational graph, so operations on it are not tracked by autograd. For instance:
import torch
x = torch.rand(5).requires_grad_(True)
y = x ** 2 + 1
Say that you are interested in the sum of the elements of x. If you run print(x.sum())
PyTorch will record the sum as a new node in the computational graph, with x as a leaf. But if you run print(x.detach().sum())
no graph node will be created.
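You can check this yourself by inspecting requires_grad and grad_fn on the results; a minimal sketch:

```python
import torch

x = torch.rand(5).requires_grad_(True)

# x.sum() is tracked: the result carries a grad_fn,
# i.e. a new node was recorded in the graph.
s = x.sum()
print(s.requires_grad)        # True
print(s.grad_fn is not None)  # True

# x.detach() shares the same storage but is cut off from
# the graph, so summing it records nothing.
d = x.detach().sum()
print(d.requires_grad)        # False
print(d.grad_fn is None)      # True
```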
The most common use of detach I’ve seen is converting a PyTorch tensor with requires_grad == True to NumPy. You can’t call x.numpy() on the tensor in the snippet above; you should call x.detach().numpy() instead.
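To see why detach is needed here, a small sketch that catches the error and then does the conversion properly:

```python
import torch

x = torch.rand(5).requires_grad_(True)

# Calling .numpy() directly raises a RuntimeError, because a NumPy
# array has no way to participate in the autograd graph.
try:
    x.numpy()
    raised = False
except RuntimeError:
    raised = True
print("x.numpy() raised:", raised)  # True

# Detach first, then convert; the resulting array shares memory with x.
arr = x.detach().numpy()
print(arr.shape)  # (5,)
```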