Torch AutoGrad behaviour after torch.detach()

Hey,

Please refer to the two code snippets below.

Snippet 1:

Y_prob = model(X)  # get the output from the model
loss = criterion(Y_prob, Y_true)  # some arbitrary loss function (PyTorch losses take (input, target))
# ...
loss.backward()

Snippet 2:

Y_prob = model(X)  # get the output from the model
loss = criterion(Y_prob, Y_true)  # some arbitrary loss function (PyTorch losses take (input, target))
Y_prob = Y_prob.detach()  # detach the output from the graph
# ...
loss.backward()

I expected the first snippet to run without errors. Snippet 2, however, should have thrown an error: I am removing the model output from the graph, so calling loss.backward() will call the backward of criterion, which in turn needs the backward of Y_prob. Since Y_prob is already detached, its grad_fn is None, so I expected this to fail.

Surprisingly, both snippets run without errors. What am I getting wrong here?

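This is basic Python (and general programming) behaviour.
You are overwriting the variable Y_prob with a new, detached tensor; the old Y_prob object itself is not modified.
The loss was already computed from the previous tensor, so the graph behind loss.grad_fn still refers to that one.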
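A minimal runnable version of snippet 2 makes this visible. It assumes a tiny nn.Linear model, nn.MSELoss and random data as hypothetical stand-ins for model, criterion, X and Y_true; it is a sketch, not the original poster's setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for model, criterion, X and Y_true
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
X = torch.randn(4, 3)
Y_true = torch.randn(4, 1)

Y_prob = model(X)                 # attached to the graph
loss = criterion(Y_prob, Y_true)  # loss.grad_fn now points into the graph built from the original Y_prob

Y_prob = Y_prob.detach()          # only rebinds the name; the graph is untouched
print(Y_prob.grad_fn)             # None: the new tensor is not part of the graph
print(loss.grad_fn is not None)   # True: loss still points into the graph

loss.backward()                   # works: autograd follows loss.grad_fn, not the name Y_prob
print(model.weight.grad is not None)  # True: gradients reached the model parameters
```

The name Y_prob is irrelevant to autograd; what matters is the chain of grad_fn nodes reachable from loss.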

It's like doing this:

x=5
print(f'id x: {id(x)}')
w=[x]
print(f'id w: {id(w[0])}')
x=6
print(f'id new x: {id(x)}')
Output:

id x: 10914624
id w: 10914624
id new x: 10914656

When you reassign x, the name x points to a new object, but w[0] still refers to the original one.
Analogously:

import torch

x = torch.tensor(5)
print(f'x old: {id(x)}')
x = x.detach()  # returns a new tensor object; x now names that new object
print(f'x new: {id(x)}')
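To see what detach() actually returns, here is a small sketch (the tensors x, y, z are purely illustrative): it is a new tensor object that shares storage with the original but has no grad_fn, while the original tensor keeps its place in the graph.

```python
import torch

x = torch.tensor([5.0], requires_grad=True)
y = x * 2          # y is part of the graph (has a grad_fn)
z = y.detach()     # new tensor object, cut off from the graph

print(z is y)                        # False: different Python objects
print(z.data_ptr() == y.data_ptr())  # True: same underlying storage
print(z.requires_grad, z.grad_fn)    # False None
print(y.grad_fn is not None)         # True: the original is unchanged
```

Because the storage is shared, in-place changes to z are visible through y; detach() avoids a copy rather than cloning the data.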

I get that the computations were performed using the older tensor. My question is: isn't the older object needed again by autograd during the backward pass?

Well, it is.
Overwriting the variable doesn't mean the object disappears.
The object stays alive as long as something (here, the autograd graph) still references it.
Its memory is freed only once nothing references it anymore, not merely when a particular variable is reassigned.
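The same lifetime rule can be shown in plain Python with weakref. In this sketch, Node and graph are hypothetical stand-ins for a tensor and for the autograd graph holding a reference to it; the immediate freeing at the end relies on CPython's reference counting.

```python
import weakref


class Node:
    """Hypothetical stand-in for a tensor object."""


obj = Node()
graph = [obj]              # stand-in for the autograd graph holding a reference
alive = weakref.ref(obj)   # observe the object without keeping it alive

obj = None                 # overwrite the variable, like Y_prob = Y_prob.detach()
still_alive = alive() is not None
print(still_alive)         # True: the "graph" still references the object

graph.clear()              # drop the last strong reference
freed = alive() is None
print(freed)               # True in CPython: freed immediately by reference counting
```

Reassigning obj did not free anything; only dropping the last reference did.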

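import ctypes

a = "hello world"
a_id = id(a)  # remember the object's address (a CPython implementation detail)
a = 141       # rebind the name; the string is still referenced elsewhere (the code object's constants)
print(ctypes.cast(a_id, ctypes.py_object).value)  # read the object back from its address
print(a)

Output:

hello world
141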

In fact, this is a common source of bugs (typically memory leaks) in PyTorch: when people carry tensors around without detaching them, every graph those tensors belong to is kept alive.
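A typical case is accumulating a running loss. This sketch uses a hypothetical scalar parameter w and a toy per-step loss; adding the tensor itself keeps each step's graph reachable, whereas .item() (or .detach()) keeps only the value.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)

total_attached = 0.0
total_plain = 0.0
for step in range(3):
    loss = (w * step) ** 2                  # hypothetical per-step loss
    total_attached = total_attached + loss  # tensor result: keeps every step's graph alive
    total_plain += loss.item()              # plain Python float: no graph attached

print(type(total_attached).__name__, total_attached.requires_grad)  # Tensor True
print(total_plain)  # 5.0
```

In a real training loop the attached version grows with every iteration, which is why logging code usually calls loss.item().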