Backward pass second time

I have to call backward twice in my code, but the second backward doesn't require anything from the first graph, so I am calling model.zero_grad() before the second backward() pass. Even so, I am still getting this error. Can someone please help me understand this problem?

This is the error: RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.

---------------------------------here is the code ---------------------------------
def _as_constant(value):
    """Return ``value`` cut off from any autograd graph; non-tensors pass through."""
    return value.detach() if torch.is_tensor(value) else value


def train(source, target, validate, alpha=5, epochs=15,
          model_name="1p_semi_supervised_with_20p_teacher", batch_size=512,
          model_lr=7e-4, coral_lr=1e-4, semi=None):
    """Train a ``DomainAdapter`` on source/target data, optionally with a semi set.

    Each epoch performs one CORAL step on ``model.coral`` using the full
    ``source.x_train`` / ``target.x_train`` tensors, followed by mini-batch
    steps that combine the source class loss with the source and target
    domain losses.  When ``semi`` is given, a sharpened-prediction MSE term
    and a domain term for the semi batch are added with weight 10.

    Args:
        source: Dataset yielding ``(features, label)`` pairs; must also
            expose ``x_train``.
        target: Dataset yielding ``(features, _)`` pairs; must also expose
            ``x_train``.
        validate: Not used by the code in this function.
        alpha: Second argument forwarded to ``model(features, alpha)``.
        epochs: Number of passes over the loaders.
        model_name: Not used by the code in this function.
        batch_size: Batch size for every ``DataLoader``.
        model_lr: Adam learning rate for all parameter groups except
            ``model.coral``.
        coral_lr: Adam learning rate for ``model.coral``.
        semi: Optional dataset yielding ``(features, label)`` pairs.

    Returns:
        The trained ``DomainAdapter`` instance.
    """
    torch.manual_seed(42)
    use_semi = semi is not None

    model = DomainAdapter()
    # model.load_state_dict(torch.load('9.pth'))
    print(use_semi)

    loss_class = nn.NLLLoss()
    loss_domain = nn.NLLLoss()
    criterion2 = nn.MSELoss()

    sourceDataLoader = DataLoader(source, batch_size=batch_size, shuffle=True)
    targetDataLoader = DataLoader(target, batch_size=batch_size, shuffle=True)

    if use_semi:
        semiDataLoader = DataLoader(semi, batch_size=batch_size, shuffle=True)
        len_semi = len(semiDataLoader)

    # model_lr / coral_lr now come from the arguments; they used to be
    # overwritten with the same values as the defaults, which silently
    # ignored anything the caller passed.
    optimizer = torch.optim.Adam([
        {'params': model.coral.parameters(), 'lr': coral_lr},
        {'params': model.feature.parameters()},
        {'params': model.class_classifier.parameters()},
        {'params': model.domain_classifier.parameters()}], lr=model_lr)
    # One "epoch" is bounded by the shorter of the two loaders.
    len_data_loader = min(len(sourceDataLoader), len(targetDataLoader))

    for epoch in range(1, epochs + 1):
        model.train()

        sourceIterator = iter(sourceDataLoader)
        targetIterator = iter(targetDataLoader)
        if use_semi:
            semiIterator = iter(semiDataLoader)

        # --- CORAL step on the full training tensors -----------------------
        # Data is treated as a constant.  NOTE(review): if any dataset tensor
        # was produced by another network (e.g. teacher pseudo-labels) and
        # still carries its graph, every backward() would walk that graph and
        # fail the second time with "Trying to backward through the graph a
        # second time".  Detaching is a no-op for plain data tensors.  If the
        # error persists, look for tensors cached inside DomainAdapter.
        s_align = model.coral(_as_constant(source.x_train))
        t_align = model.coral(_as_constant(target.x_train))
        model.zero_grad()
        err_coral = CORAL(s_align, t_align)
        err_coral.backward()
        optimizer.step()

        # --- mini-batch steps ----------------------------------------------
        for batch in range(len_data_loader):
            # Restart the semi loader whenever it would run out.
            if use_semi and batch % len_semi == 0:
                semiIterator = iter(semiDataLoader)

            model.zero_grad()

            # Source domain: class loss plus domain loss (domain label 0).
            s_features, s_label = (_as_constant(t) for t in next(sourceIterator))
            s_label = s_label.long()
            domain_label = torch.zeros(len(s_label)).long()
            class_output, domain_output = model(s_features, alpha)
            err_s_label = loss_class(class_output, s_label)
            err_s_domain = loss_domain(domain_output, domain_label)

            # Target domain: domain loss only (domain label 1).
            t_features, _ = next(targetIterator)
            t_features = _as_constant(t_features)
            domain_label = torch.ones(len(t_features)).long()
            _, domain_output = model(t_features, alpha)
            err_t_domain = loss_domain(domain_output, domain_label)

            # Always defined, so the step also works when ``semi`` is None
            # (previously ``err_all`` only existed inside the semi branch and
            # the backward call raised NameError without a semi dataset).
            err_all = err_t_domain + err_s_domain + err_s_label

            if use_semi:
                semi_features, semi_label = (
                    _as_constant(t) for t in next(semiIterator))
                domain_label = torch.ones(len(semi_label)).long()
                class_output, domain_output = model(semi_features, alpha)
                # exp() of the model's class output, column 1.
                class_output = torch.exp(class_output)[:, 1]

                # Sharpen with temperature 0.5.
                # NOTE(review): torch.sum(n) has no ``dim``, so the values are
                # normalised over the whole batch - confirm this is intended.
                temperature = 0.5
                n = torch.pow(class_output, 1 / temperature)
                sharpened_arr = n / torch.sum(n)

                err_semi_domain = loss_domain(domain_output, domain_label)
                err_semi_class = criterion2(sharpened_arr, semi_label)
                err_all = err_all + 10 * (err_semi_class + err_semi_domain)

            err_all.backward()
            optimizer.step()

    # Hand the trained model back; the visible code otherwise discards it.
    return model

This is expected. When you call .backward() on a tensor (say x) for the first time, autograd frees all the references to the saved tensors in the computation graph of x - this is the aggressive memory-freeing mechanism autograd relies on.

As a result, a second backward call on x will not be able to calculate the gradients, because the references to the saved tensors required for the gradient computation have already been lost.

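Try passing retain_graph=True as an argument to your first backward call, i.e. err_coral.backward(retain_graph=True). Note that this only helps when the two backward passes really share part of a graph; if they are meant to be independent, look for a tensor that is reused across iterations while still attached to an old graph and detach it instead.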