Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time

I checked the previous solutions for the same problem, which advise to add detach to any intermediate result obtained from implementing the mode but to no avail.
Here is a snipped of the code and some of the functions used.


Could you post a minimal, executable code snippet reproducing the error, please?
You can wrap the code into three backticks ```, which makes debugging easier since we cannot copy the code from your screenshots.

Here is the full code.

from scipy.stats import pearsonr

class ModelTrainer:
    
    def __init__(self, input_sizes, hidden_dim):
        self.input_sizes = input_sizes
        self.h_dim = hidden_dim
        
        model_nn = self.ae_model()
        self.net = model_nn
        self.optimizer = torch.optim.SGD(self.net.parameters(),lr=1e-5)
        self.train_dt_1, self.train_dt_2, self.train_dt_3 = self.generate_train_test_dts()
        self.log_interval = 100
        
    def generate_train_test_dts(self, batch_size=16):
        X_1, y_1 = torch.randn(10000, 64), torch.randn(10000, 1)
        X_2, y_2 = torch.randn(10000, 48), torch.randn(10000, 1)
        X_3, y_3 = torch.randn(10000, 64), torch.randn(10000, 1)
        
        train_dt_1 = TensorDataset(X_1, y_1)
        train_dt_2 = TensorDataset(X_2, y_2)
        train_dt_3 = TensorDataset(X_3, y_3)
        
        dt_1_loader = DataLoader(train_dt_1, batch_size=16)
        dt_2_loader = DataLoader(train_dt_2, batch_size=16)
        dt_3_loader = DataLoader(train_dt_3, batch_size=16)

        
        return dt_1_loader,dt_2_loader, dt_3_loader 

    
    def ae_model(self):
        model_nn = Corr_Net(self.input_sizes, self.h_dim)
        return model_nn
    
    def calculate_l2_loss(self, set_inputs, lambda_val =0.02):
        sum_combinations = 0.0
        input_names = list(set_inputs.keys())
        all_combinations = np.array(np.meshgrid(input_names, input_names)).T.reshape(-1, 2)
        
        for combination in all_combinations:
            input_1, input_2 = combination[0], combination[1]
            if input_1 == input_2:
                continue
            
            sum_combinations += self.calculate_correlation(set_inputs[input_1], set_inputs[input_2])
            
        return -lambda_val * sum_combinations
    
    def calculate_l1_loss(self, set_inputs, loss_fn):
        r_l1_loss, set_hidden_rep = 0.0, {}
        for i in range(len(set_inputs)):
            x_input = []
            for idx, input_batch in enumerate(set_inputs):
                if i == idx:
                    x_input.append(input_batch)
                else:
                    x_input.append(torch.zeros_like(input_batch))
                    

            temp_output = self.net(x_input)
            rel_output = temp_output[i]
            r_l1_loss += loss_fn(x_input[i], rel_output.detach())
            hidden_rep = self.net.out[i].detach()
            set_hidden_rep[f'input_{i}'] = hidden_rep
            
        return r_l1_loss, set_hidden_rep
        
    def calculate_correlation(self, h1, h2):
        h1_mean, h2_mean = torch.mean(h1), torch.mean(h2)
        h1_centered, h2_centered = torch.subtract(h1, h1_mean), torch.subtract(h2, h2_mean)
        corr_nr = torch.sum(torch.multiply(h1_centered, h2_centered))
        
        corr_dr1 = torch.sqrt(torch.square(h1_centered))
        corr_dr2 = torch.sqrt(torch.square(h2_centered))
      
        corr_dr = torch.add(torch.multiply(corr_dr1, corr_dr2), 1e-5)
        return corr
        
    def train(self, epoch):
        self.net.train()
        
        running_loss = 0.0
        
        itr_loader_2, itr_loader_3 = iter(self.train_dt_2), iter(self.train_dt_3)
        
        for i , (input_dt_1, label_dt_1) in enumerate(self.train_dt_1):
            (input_dt_2, labels_dt_2) = next(itr_loader_2)
            (input_dt_3, labels_dt_3) = next(itr_loader_3)
            
            set_inputs = [input_dt_1, input_dt_2, input_dt_3]
            
            self.optimizer.zero_grad()
            outputs = self.net(set_inputs)
            
            #l1_loss that minimizes the reconstruction error when other inputs are inexistent
            loss=nn.MSELoss()
            l1_loss, set_hidden_rep = self.calculate_l1_loss(set_inputs, loss)
            
            #l2_loss that calculates the correlation between hidden representations to encourage the hidden units 
            #of the representation to be shared between the representations
            l2_loss = self.calculate_l2_loss(set_hidden_rep, lambda_val =0.02)
            print(l1_loss, l2_loss)
            
            loss_values = l1_loss + l2_loss
            loss_values.backward()
            
            self.optimizer.step()
            
            running_loss += loss_values
            
            if i % self.log_interval == 0:
                print('[%d, %d/%d] Loss L1 %.3f, Loss L2 %.3f, Loss' %(epoch, i, len(self.train_dt_1)
                                                                       , l1_loss, l2_loss, (running_loss / ((i+1) * len(outputs[0])))))```
Thank you for your help!

Thanks, but it’s not executable as e.g. Corr_Net is undefined. Make sure you can copy/paste the posted code into a new script and directly execute it to reproduce the issue.

Here is the code for the Corr_Net:

class Corr_Net(nn.Module):
    
    def __init__(self, input_sizes, hidden_dim):
        super(Corr_Net, self).__init__()
        self.encoders = nn.ModuleList()
    
        for input_size in input_sizes:
            self.encoders.append(
                nn.Sequential(
                    nn.Linear(input_size, hidden_dim),
                    nn.ReLU()
                ))
            
        self.decoders = nn.ModuleList()
        
        for input_size in input_sizes:
            self.decoders.append(nn.Sequential(
                    nn.Linear(hidden_dim, input_size),
                    nn.ReLU()
                ))
            
    
    def forward(self, inputs):
        out = []
        for idx, enc in enumerate(self.encoders):
            out.append(enc(inputs[idx]))
            
        self.out = out
        
        common_rep = out[0]
        for i in range(1, len(out)):
            common_rep = torch.add(common_rep, out[idx])
        
        self.common_rep = common_rep
        
        reconstructed_layer = []
        
        for i in range(len(out)):
            reconstructed_layer.append(self.decoders[i](out[i]))
        
        return reconstructed_layer

m_t = ModelTrainer([64, 48, 64], 10)
for epoch in range(0, 5):
    m_t.train(epoch)

Thanks for the update.
Your losses do not depend on any parameters used in the model as they are calculated using the inputs:

            set_inputs = [input_dt_1, input_dt_2, input_dt_3]
            
            self.optimizer.zero_grad()
            outputs = self.net(set_inputs)
            
            #l1_loss that minimizes the reconstruction error when other inputs are inexistent
            loss=nn.MSELoss()
            l1_loss, set_hidden_rep = self.calculate_l1_loss(set_inputs, loss)
            
            #l2_loss that calculates the correlation between hidden representations to encourage the hidden units 
            #of the representation to be shared between the representations
            l2_loss = self.calculate_l2_loss(set_hidden_rep, lambda_val =0.02)

As you can see, l1_loss is calculated using set_inputs and l2_loss using the output of calculate_l1_loss.

Thank you for your answer.
In calculate_l1_loss, I am using the network to obtain temp_output.

Yes, but you are detaching it, so these tensors are constants:

            r_l1_loss += loss_fn(x_input[i], rel_output.detach())
            hidden_rep = self.net.out[i].detach()
1 Like

Thank you again for your answer. The issue is related to a type. The updated code for calculate_l1_loss and calculate_l2_loss is as follows:

def calculate_l2_loss(self, set_inputs, lambda_val =0.02):
        sum_combinations = 0.0
        input_names = list(set_inputs.keys())
        all_combinations = np.array(np.meshgrid(input_names, input_names)).T.reshape(-1, 2)
        
        for combination in all_combinations:
            input_1, input_2 = combination[0], combination[1]
            if input_1 == input_2:
                continue
            
            sum_combinations += self.calculate_correlation(set_inputs[input_1], set_inputs[input_2], lambda_val)
            
        return sum_combinations
    
    def calculate_l1_loss(self, set_inputs, loss_fn):
        r_l1_loss, set_hidden_rep = 0.0, {}
        for i in range(len(set_inputs)):
            x_input = []
            for idx, input_batch in enumerate(set_inputs):
                if i == idx:
                    x_input.append(input_batch)
                else:
                    x_input.append(torch.zeros_like(input_batch))
                    

            temp_output = self.net(x_input)
            rel_output = temp_output[i]
            r_l1_loss += loss_fn(x_input[i], rel_output)
            hidden_rep = self.net.out[i]
            set_hidden_rep[f'input_{i}'] = hidden_rep
            
        return r_l1_loss, set_hidden_rep
        
    def calculate_correlation(self, h1, h2, lambda_val):
        h1_mean, h2_mean = torch.mean(h1), torch.mean(h2)
        h1_centered, h2_centered = torch.subtract(h1, h1_mean), torch.subtract(h2, h2_mean)
        corr_nr = torch.sum(torch.multiply(h1_centered, h2_centered))
        
        corr_dr1 = torch.sqrt(torch.square(h1_centered))
        corr_dr2 = torch.sqrt(torch.square(h2_centered))
      
        corr_dr = torch.add(torch.multiply(corr_dr1, corr_dr2), 1e-5)
        corr = torch.multiply(torch.divide(corr_nr, corr_dr), -lambda_val)
        
        return torch.mean(corr)

Thank you for your help.