To complete my understanding, based on what @ptrblck said: `loss.retain_grad()` initializes `loss.grad` to `None`, not `1` like I said above. It is the subsequent call to `loss.backward()` that sets `grad` to `1`. This is analogous to what happens when we specify `requires_grad=True`: the `grad` value is initially `None`, and is set to numerical values after a call to `backward()`.
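A minimal check of that claim, using a hypothetical toy tensor rather than the model from the tutorial below:

```python
import torch

# retain_grad() leaves .grad at None; it is backward() that fills it
# with 1 for a scalar loss.
t = torch.randn(3, requires_grad=True)
loss = t.sum()                  # scalar, non-leaf
loss.retain_grad()
assert loss.grad is None        # retain_grad() alone does not set .grad
loss.backward()
assert loss.grad.item() == 1.0  # backward() implicitly seeds grad with 1
```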
Here is a mini-tutorial in code that I wrote to get my head around these concepts:
import torch
from torch import nn
model = nn.Linear(1, 1)
print("#"*7)
print("Case 1")
print("#"*7)
# Case 1: No tensor retains its gradient, so none of the computed
# gradients is saved on x or loss
# This is a leaf, but requires_grad is False, so its grad will stay None
x = torch.randn(1, 1)
print(x)
print(x.is_leaf) # True
print(x.grad) # None
# Forward pass, compute loss
out = model(x)
loss = out**2
# The following will print None, since we haven't computed any gradients
# yet. (Also because neither tensor retains its gradient!)
print(loss.grad) # None
print(x.grad) # None
# This computes the gradients of loss w.r.t. the leaf tensors that
# require grad (here, only the model's parameters; x doesn't require
# grad). Since loss is a 1-element tensor, this call implicitly uses
# torch.ones_like(loss) as the upstream gradient. That value is not
# stored in loss.grad, because loss is a non-leaf and we did not call
# retain_grad() on it.
loss.backward()
# The following will print None: loss does not retain its gradient,
# and x never required one in the first place
print(loss.grad) # None
print(x.grad) # None
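A side check on Case 1 (a sketch, separate from the tutorial's model): even though `x.grad` stays `None`, `backward()` did compute and store gradients on the model's parameters, which are leaves with `requires_grad=True`.

```python
import torch
from torch import nn

model = nn.Linear(1, 1)
x = torch.randn(1, 1)                 # requires_grad defaults to False
loss = model(x) ** 2
loss.backward()
assert x.grad is None                 # the input never asked for a gradient
assert model.weight.grad is not None  # but the parameters received theirs
assert model.bias.grad is not None
```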
print("")
print("*"*7)
print("Case 2")
print("*"*7)
# Case 2: The leaf tensor requires grad, so its computed gradient
# is saved
# This is a leaf, and gradients will be saved on it because we set
# requires_grad=True. Its grad attribute starts out as None.
y = torch.randn(1, 1, requires_grad=True)
print(y)
print(y.is_leaf) # True
print(y.grad) # None
# Forward pass, compute loss
out = model(y)
loss = out**2
# The following will print None, since we haven't computed any gradients
# yet
print(loss.grad) # None
print(y.grad) # None
# This computes the gradients of loss w.r.t. the leaf tensors that
# require grad (y and the model's parameters). Since loss is a
# 1-element tensor, this call implicitly uses torch.ones_like(loss)
# as the upstream gradient. That value is not stored in loss.grad,
# because loss is a non-leaf and we did not call retain_grad() on it.
loss.backward()
# The following will print None since, while we did compute the
# gradients of loss w.r.t. y, loss does not retain its own gradient
print(loss.grad) # None
# This will print a [1, 1] float tensor, which is the gradient of
# loss w.r.t y
print(y.grad)
print("")
print("*"*7)
print("Case 3")
print("*"*7)
# Case 3: Both the leaf tensor and the loss (which is not a leaf)
# retain their gradients, so both are saved
# This is a leaf, and gradients will be saved on it because we set
# requires_grad=True. Its grad attribute starts out as None.
z = torch.randn(1, 1, requires_grad=True)
print(z)
print(z.is_leaf) # True
print(z.grad) # None
# Forward pass, compute loss
out = model(z)
loss = out**2
# Ask autograd to retain the gradient of the non-leaf tensor loss
loss.retain_grad()
# The following will print None, since we haven't computed any gradients
# yet
print(loss.grad) # None
print(z.grad) # None
# This computes the gradients of loss w.r.t. the leaf tensors that
# require grad (z and the model's parameters). Since loss is a
# 1-element tensor, this call implicitly uses torch.ones_like(loss)
# as the upstream gradient. Because we called retain_grad() on loss,
# that value is saved in loss.grad after the call returns.
loss.backward()
# The following will print a [1, 1] float tensor whose element is 1.0.
# This is the value which was implicitly set by the call to loss.backward()
print(loss.grad) # [[1.0]]
# This will print a [1, 1] float tensor, which is the gradient of loss
# w.r.t z
print(z.grad)
# Case 4: Same as Case 3, except we explicitly set the grad attribute
# of loss to a large value.
# Observe that the grad of the leaf changes proportionately.
print("")
print("*"*7)
print("Case 4")
print("*"*7)
w = torch.randn(1, 1, requires_grad=True)
print(w)
print(w.is_leaf) # True
print(w.grad) # None
out = model(w)
loss = out**2
loss.retain_grad()
print(loss.grad) # None
print(w.grad) # None
# Pass an explicit upstream gradient of 1e5 instead of the implicit 1.0
loss.backward(torch.tensor(1e5).view_as(loss))
print(loss.grad) # [[100000.0]]
print(w.grad)
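A sketch verifying Case 4's claim numerically: the leaf's grad scales in proportion to the upstream gradient passed to `backward()`.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(1, 1)
w = torch.randn(1, 1, requires_grad=True)

loss = model(w) ** 2
loss.backward()                  # implicit upstream gradient of 1.0
base = w.grad.clone()

w.grad = None                    # reset before the second pass
loss = model(w) ** 2
loss.backward(torch.full_like(loss, 1e5))
assert torch.allclose(w.grad, base * 1e5)  # grad scaled by the same factor
```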