I was trying to understand how PyTorch decides which operations to include in the computation graph that eventually does backprop, and which to exclude.
For the simplest example I could think of, consider a (maybe ugly) manual implementation of SGD:
for i in range(nb_iter):
    # Forward pass: compute predicted Y using operations on Variables
    batch_xs, batch_ys = get_batch(X_train, Y_train, M, dtype)  # [M, D], [M, 1]
    ## FORWARD PASS
    y_pred = mdl.forward(batch_xs)
    ## LOSS
    batch_loss = (1/M)*(y_pred - batch_ys).pow(2).sum()
    ## BACKWARD PASS
    batch_loss.backward()  # Use autograd to compute the backward pass. Now W will have gradients
    ## SGD update
    for W in mdl.parameters():
        delta = eta*W.grad.data
        W.data.copy_(W.data - delta)
    ## Manually zero the gradients after updating weights
    mdl.zero_grad()
I guess my question is: of course batch_loss is a function of mdl. But mdl has now had a new operation applied to it, namely w - eta*loss.grad(w). How is it that when we call batch_loss.backward() the next time through the loop, it knows NOT to include that change to the parameters in the computation graph?
The reason I find it confusing is that I was told the following on the PyTorch forum:
if the loss is a result of Op will the gradients be computed using Op. Otherwise, Op wont have connection to the gradients computed.
In this case Op is the gradient update. Since it obviously affects batch_loss, how come it's excluded from the computation graph? Shouldn't it be included? I know that would be incorrect SGD, but it should at least be there according to the specs of PyTorch, or at least the mental model of PyTorch I have. Obviously I am wrong, thus my questions are:
where did I go wrong? What did I misunderstand?
In particular, it seems the important thing I need to understand is: when are NEW computation graphs created as we loop? How is it that the loss belongs to a fresh graph each iteration and the old one isn't accidentally re-used?
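To make that last question concrete, here is how I currently picture it (a sketch under my assumptions, with made-up tensor names): every forward pass creates brand-new intermediate tensors, so the loss points at a brand-new graph each time, and backward() frees the previous one:

```python
import torch

W = torch.randn(2, requires_grad=True)
x = torch.randn(2)

graph_roots = []
for _ in range(3):
    # Each forward pass builds new intermediate tensors, so `loss`
    # is the root of a brand-new graph on every iteration.
    loss = ((W * x).sum() - 1.0) ** 2
    graph_roots.append(loss.grad_fn)
    loss.backward()           # consumes/frees this iteration's graph
    with torch.no_grad():     # update excluded from any graph
        W -= 0.1 * W.grad
    W.grad.zero_()

# The grad_fn roots are distinct objects: new graphs, not re-use
print(len({id(g) for g in graph_roots}))  # 3
```

If this picture is right, the parameter update simply never appears in any of those per-iteration graphs.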