I have to call backward() twice in my code, but as far as I can tell the second backward() doesn't need anything from the first graph. I call model.zero_grad() before the second backward() pass, yet I still get the error below. Can someone please help me understand this problem?
This is the error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
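As far as I understand it, this error shows up whenever a second backward() call runs through a part of the graph whose saved tensors were already freed by the first call. A toy example (nothing to do with my actual model) that raises the same error:

import torch

w = torch.randn(3, requires_grad=True)
h = w.exp()               # intermediate shared by both losses
loss1 = h.sum()
loss2 = (h * 3).sum()
loss1.backward()          # frees the tensors saved in the shared part of the graph
loss2.backward()          # raises the same RuntimeError

In my training function below, the two backward() calls are err_coral.backward() and err_all.backward().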
--------------------------------- here is the code ---------------------------------
def train(source, target, validate, alpha=5, epochs=15,
          model_name="1p_semi_supervised_with_20p_teacher",
          batch_size=512, model_lr=7e-4, coral_lr=1e-4, semi=None):
    torch.manual_seed(42)
    is_source = True
    is_semi = False
    if semi is not None:
        is_semi = True
    model = DomainAdapter()
    # model.load_state_dict(torch.load('9.pth'))
    model_lr = 7e-4
    coral_lr = 1e-4
    print(is_semi)
    # loss1 = criterion1(labelled_outputs, labelled_labels)
    loss_class = nn.NLLLoss()
    loss_domain = nn.NLLLoss()
    criterion2 = nn.MSELoss()
    dictionary = {}
    sourceDataLoader = DataLoader(source, batch_size=batch_size, shuffle=True)
    targetDataLoader = DataLoader(target, batch_size=batch_size, shuffle=True)
    if is_semi:
        semiDataLoader = DataLoader(semi, batch_size=batch_size, shuffle=True)
        len_semi = len(semiDataLoader)
    optimizer = torch.optim.Adam([
        {'params': model.coral.parameters(), 'lr': coral_lr},
        {'params': model.feature.parameters()},
        {'params': model.class_classifier.parameters()},
        {'params': model.domain_classifier.parameters()}], lr=model_lr)
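    # a single Adam optimizer: the coral parameters get their own, smaller
    # learning rate (coral_lr); the other groups fall back to the default model_lr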
    len_data_loader = min(len(sourceDataLoader), len(targetDataLoader))
    for epoch in range(1, epochs + 1):
        model.train()
        # plotdist(model, validate)
        sourceIterator = iter(sourceDataLoader)
        targetIterator = iter(targetDataLoader)
        if is_semi:
            semiIterator = iter(semiDataLoader)
        # print(source.x_train.shape)
        # print(target.x_train.shape)
        # coral loss
        s_align = model.coral(source.x_train)
        t_align = model.coral(target.x_train)
        # model.coral.state_dict()
        model.zero_grad()
        err_coral = CORAL(s_align, t_align)
        # print("Coral Loss:", err_coral.item())
        err_coral.backward()
        optimizer.step()
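        # first backward()/step() of this epoch: after err_coral.backward()
        # the intermediate values saved by the two model.coral(...) calls
        # above are freed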
        for batch in range(len_data_loader):
            # print(batch)
            # print(len_data_loader)
            if is_semi:
                if batch % len_semi == 0:
                    semiIterator = iter(semiDataLoader)
            model.zero_grad()
            # source domain
            data_source = next(sourceIterator)
            s_features, s_label = data_source
            s_label = s_label.long()
            batch_size = len(s_label)
            domain_label = torch.zeros(batch_size).long()
            class_output, domain_output = model(s_features, alpha)
            err_s_label = loss_class(class_output, s_label)
            err_s_domain = loss_domain(domain_output, domain_label)
            # target domain
            data_target = next(targetIterator)
            t_features, _ = data_target
            batch_size = len(t_features)
            domain_label = torch.ones(batch_size).long()
            _, domain_output = model(t_features, alpha)
            err_t_domain = loss_domain(domain_output, domain_label)
            if is_source:
                err = err_t_domain + err_s_domain + err_s_label
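            # err combines the source class loss with the source/target
            # domain losses; it is reused directly below when building err_all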
            # semi domain
            if is_semi:
                data_semi = next(semiIterator)
                semi_features, semi_label = data_semi
                # semi_label = semi_label.long()
                batch_size = len(semi_label)
                domain_label = torch.ones(batch_size).long()
                class_output, domain_output = model(semi_features, alpha)
                class_output = torch.exp(class_output)[:, 1]
                # err_semi_label = loss_class(class_output, semi_label)
                # print(class_output)
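                # temperature sharpening of the predicted class-1 probability:
                # p ** (1 / t), renormalised over the batch, with t = 0.5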
                t = 0.5
                n = torch.pow(class_output, 1 / t)
                d = torch.sum(n)
                sharpened_arr = n / d
                # print('shape of semi_features: ', semi_features.shape)
                # print('shape of semi_label: ', semi_label.shape)
                # print('shape of semi_class_output: ', class_output.shape)
                err_semi_domain = loss_domain(domain_output, domain_label)
                err_semi_class = criterion2(sharpened_arr, semi_label)
                err_all = err + 10 * (err_semi_class + err_semi_domain)
            err_all.backward()  # second backward pass, where the error is raised
            optimizer.step()
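For what it's worth, my understanding is that model.zero_grad() only resets the .grad buffers of the parameters and neither keeps nor rebuilds the freed graph, so it cannot prevent this error. A toy snippet (again unrelated to my model) that seems to confirm this:

import torch
import torch.nn as nn

net = nn.Linear(3, 1)
out = net(torch.randn(2, 3)).sum()
out.backward()    # the saved intermediate values of this graph are freed here
net.zero_grad()   # only clears the .grad buffers; the graph stays freed
out.backward()    # raises the same RuntimeError

Is retain_graph=True really the right fix here, or is some tensor being shared between my two graphs without me noticing?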