Regarding the detach function

Can somebody explain in simple words, with an illustration, what the detach function does in PyTorch?

There are basically two things to understand before talking about detach:

  • What is a computational graph?
  • How does PyTorch build a dynamic computational graph?
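As a quick refresher on those two points, here is a minimal sketch (the tensor values are arbitrary) showing that autograd records each operation as it runs, building the graph on the fly, and that backward() walks that graph back to the leaf:

```python
import torch

# A leaf tensor: created directly by the user, with gradient tracking on.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Each operation is recorded while it executes, so the graph is dynamic.
y = x ** 2 + 1

print(x.is_leaf, x.grad_fn)  # leaf tensors have no grad_fn
print(y.is_leaf, y.grad_fn)  # y is produced by an op, so it has a grad_fn

# Backpropagating from a scalar traverses the recorded graph back to x.
y.sum().backward()
print(x.grad)  # d/dx sum(x**2 + 1) = 2 * x
```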

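I’m assuming that you know the answers to the questions above. So, detach is a method that returns a new tensor sharing the same data as the original one but cut off from the computational graph: it has requires_grad == False and no grad_fn. Operations on that tensor are not recorded by autograd, which lets you manipulate it without expanding the computational graph. For instance: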

```python
import torch

x = torch.rand(5).requires_grad_(True)  # leaf tensor tracked by autograd
y = x ** 2 + 1                          # recorded in the graph, so y has a grad_fn
```

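Say that you are interested in the sum of the elements of x. If you run print(x.sum()), PyTorch records the summation as a new node in a computational graph having x as its leaf, which is why the printed result carries grad_fn=<SumBackward0>. If you run print(x.detach().sum()) instead, nothing is recorded and no additional graph is created; the result has requires_grad == False.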
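The difference can be checked directly. The sketch below recreates x as in the snippet above and compares the two sums; it also shows two related properties of detach(): the detached tensor shares storage with the original (no copy is made), and gradients do not flow back through it. The names s_tracked, s_detached, x_d and z are illustrative only.

```python
import torch

x = torch.rand(5).requires_grad_(True)

s_tracked = x.sum()            # recorded by autograd (SumBackward0 node)
s_detached = x.detach().sum()  # not recorded, no graph is built

print(s_tracked.requires_grad, s_tracked.grad_fn)
print(s_detached.requires_grad, s_detached.grad_fn)  # False None

# detach() does not copy data: the result points at the same storage as x.
x_d = x.detach()
print(x_d.data_ptr() == x.data_ptr())  # True

# A detached tensor acts as a constant for autograd. In x * x.detach()
# only the first factor is differentiated, so the gradient equals x's
# values instead of 2 * x.
z = (x * x.detach()).sum()
z.backward()
print(torch.allclose(x.grad, x.detach()))  # True
```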

The most common use of detach I’ve seen is converting PyTorch tensors that have requires_grad == True into NumPy arrays. Calling x.numpy() on x from the snippet above raises a RuntimeError; you should run x.detach().numpy() instead.
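A minimal sketch of that conversion on a CPU tensor follows. One caveat worth knowing: the resulting array shares memory with x, so writing into it also changes x (use .copy() on the array, or x.detach().clone() first, if you need an independent copy). For a tensor on the GPU you would additionally need .cpu() before .numpy(). The variable numpy_failed is just an illustrative flag.

```python
import torch

x = torch.rand(5).requires_grad_(True)

try:
    x.numpy()  # not allowed: x requires grad
    numpy_failed = False
except RuntimeError as err:
    numpy_failed = True
    print("RuntimeError:", err)

arr = x.detach().numpy()  # works: the detached tensor does not require grad
print(type(arr), arr.shape)

# For a CPU tensor the array and x share memory, so an in-place
# write to the array is visible through x as well.
arr[0] = 42.0
print(x[0].item())  # 42.0
```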


Thanks @LeviViana