Hi,
I have read many forum posts trying to understand .detach(), and although I have an intuition of what it does, I still don't understand it completely. I don't understand what removing a tensor from the computational graph implies. For example, consider the piece of code below:
import torch

x = torch.tensor([1.0], requires_grad=True)
y = x**2
z = 2*y
w = z**3
z = z.detach()
if x.grad is not None:
    x.grad.zero_()
w.backward()
print(x.grad)
The output came out to be 48, which is the same with or without the `z.detach()` line. What does `z.detach()` do here? Why wasn't z removed from the computational graph?
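For comparison, here is a second experiment I tried (my own sketch, not from any docs) where the detached tensor actually participates in building w, and there the gradient does change, because the detached copy is treated as a constant:

```python
import torch

# Same setup as above, but here detach() is applied BEFORE w is built,
# so the detached copy really is used in the forward computation.
x = torch.tensor([1.0], requires_grad=True)
y = x**2            # y = 1
z = 2*y             # z = 2
w = z * z.detach()  # the detached factor is treated as a constant
w.backward()
print(x.grad)       # tensor([8.]) -- only the non-detached z contributes
# Without the detach, w = z**2 would give dw/dx = 2*z*dz/dx = 16.
```

So it looks like detaching only matters for tensors that are actually used after the detach, which is what confuses me about my first snippet.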
Thanks!