Understanding where gradients are stored in backward

As @smth explains here, non-leaf gradients are not retained by default to save memory. You can use the hooks mentioned in his post or call .retain_grad() on the particular intermediate results:

import torch

N, D_in, D_out, H = 1, 10, 1, 5
device = 'cpu'

x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

# Intermediate (non-leaf) results: ask autograd to keep their gradients
x_mm1 = x.mm(w1)
x_mm1.retain_grad()
x_clamp = x_mm1.clamp(min=0)
x_clamp.retain_grad()
y_pred = x_clamp.mm(w2)

# y_pred holds a single element here, so backward() needs no gradient argument
y_pred.backward()
print(w1.grad)
print(w2.grad)
print(x_clamp.grad)
print(x_mm1.grad)
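
If you only want to inspect the intermediate gradients rather than store them on the tensors, the hook-based approach works too. A minimal sketch of the same network using tensor.register_hook (the hooks here just print the incoming gradient; you could also append it to a list):

import torch

N, D_in, D_out, H = 1, 10, 1, 5
x = torch.randn(N, D_in)
w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

x_mm1 = x.mm(w1)
# The hook fires when the gradient flows through this node during backward
x_mm1.register_hook(lambda grad: print('x_mm1 grad:', grad))
x_clamp = x_mm1.clamp(min=0)
x_clamp.register_hook(lambda grad: print('x_clamp grad:', grad))
y_pred = x_clamp.mm(w2)

y_pred.backward()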