I’m trying to understand what tensors does autograd free after calling .backward(), in the below code, case 1 works as expected, but case 3 doesn’t work unless adding create_graph=True as case 2. This looks odd to me because the gradient is supposed to be computed over model1’s parameters but it seems that model2’s parameters are freed or something like that happens.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class A(nn.Module):
def __init__(self):
super(A, self).__init__()
self.layer1 = nn.Linear(4, 20)
self.layer2 = nn.Linear(20, 1)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
return x
class B(nn.Module):
def __init__(self):
super(B, self).__init__()
self.layer1 = nn.Linear(4, 20)
self.layer2 = nn.Linear(20, 1)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
return x
# case 1
if __name__ == '__main__':
model1 = A()
opt1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
model2 = B()
data = torch.randn(100, 4)
label = torch.randn(100, 1)
for epoch in range(10):
print("{}th backward...".format(epoch))
extra = model2(data)
loss = torch.mean((extra - label - model1(data)).pow(2))
opt1.zero_grad()
loss.backward()
opt1.step()
# case 2
if __name__ == '__main__':
model1 = A()
opt1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
model2 = B()
data = torch.randn(100, 4)
label = torch.randn(100, 1)
extra = model2(data).detach()
for epoch in range(10):
print("{}th backward...".format(epoch))
loss = torch.mean((extra - label - model1(data)).pow(2))
opt1.zero_grad()
loss.backward()
opt1.step()
# case 3
if __name__ == '__main__':
model1 = A()
opt1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
model2 = B()
data = torch.randn(100, 4)
label = torch.randn(100, 1)
extra = model2(data)
for epoch in range(10):
print("{}th backward...".format(epoch))
loss = torch.mean((extra - label - model1(data)).pow(2))
opt1.zero_grad()
loss.backward()
opt1.step()