Error when connecting multiple networks

I am trying to implement a neural network of the following form:

Stage 1 - Inputs of shape (10) are fed into two feed-forward NNs, which I will name A and B respectively. A and B both have output shape (10). Note that A and B have the same structure, but I want their parameters to be trained independently.

Stage 2 - Given an input x_0, let a_0, b_0 = A(x_0), B(x_0) (the outputs of the networks A and B when x_0 is fed in). I then perform a differentiable calculation involving a_0 and b_0 to yield a new tensor x_1 of shape (10).

Stage 3 - Using x_1 as the new input, I repeat stage 2 n times to finally calculate x_n.

Stage 4 - I use x_n as the input to a third feed-forward NN, which I will name C. C has input shape (10) and output shape (1).
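To make this concrete, here is a minimal runnable sketch of the computation I have in mind (the modules and the combining function f below are toy stand-ins, not my real code):

import torch
from torch import nn

# Toy stand-ins for the real networks A, B, C.
A = nn.Linear(10, 10)
B = nn.Linear(10, 10)
C = nn.Linear(10, 1)

def f(x, a, b):
    # Toy placeholder for the differentiable calculation in stage 2.
    return x + a - b

n = 3
x = torch.randn(10)       # stage 1 input of shape (10)
for _ in range(n):        # stages 2-3: repeat the update n times
    a, b = A(x), B(x)
    x = f(x, a, b)
y = C(x)                  # stage 4: output of shape (1)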

I only update the parameters after stage 4 is completed. I have been getting some errors in my implementation of this network. The main error I have been receiving is:

‘Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.’

I am not aware of the ‘proper’ way to implement this in PyTorch (I am fairly new in general). Below I will post my code. I’d be very grateful if anyone could help me or offer some general guidance on implementing this sort of network. For the record, the number n is replaced by L in the code, and A, B, C are ‘reward’, ‘penalty’, ‘g_mod’ respectively.

import torch
from torch import nn
from torch.nn import functional as F


device = "cuda" if torch.cuda.is_available() else "cpu"
# NN Modules

class RP(nn.Module):
    def __init__(self, layer_size: int, num_layers: int ):
        super(RP, self).__init__()
        self.layer_size = layer_size
        self.num_layers = num_layers
        
        self.ReLU = nn.ReLU()
        
        # Hidden layers
        hidden_layer_list = [nn.Linear(10, layer_size), nn.ReLU()]
        for _ in range(num_layers-1):
            hidden_layer_list.append(nn.Linear(layer_size, layer_size))
            hidden_layer_list.append(nn.ReLU())
        
        self.hidden = nn.Sequential(*hidden_layer_list)
            
        # Output layer
        self.out = nn.Sequential(
            nn.Linear(layer_size, 10),
            nn.Sigmoid(),
        )
    
    def forward(self, x1, x2):
        """
        Args:
            x1: tensor [w_i^(l)]_{i \in T1}
            x2: tensor [w_i^(l)]_{i \in T2}

        Returns:
            tuple (r1, r2) where:
                r1:tensor [R_{AB, i}^(l)]_{i \in T1} (replace R with P respectively)
                r2: tensor [R_{AB, i}^(l)]_{i \in T2} (replace R with P respectively)
        """
        x = torch.cat((x1, x2), dim=1) # Concatenate into input 
        x = self.ReLU(x) # Apply first ReLU
        x = self.hidden(x) # Hidden layers
        x = self.out(x) # Out
        
        return (x[:, 0:5], x[:, 5:10])
        
        
class G(nn.Module):
    def __init__(self, layer_size: int, num_layers: int ):
        super(G, self).__init__()
        self.layer_size = layer_size
        self.num_layers = num_layers
        
        self.ReLU = nn.ReLU()
        
        # Hidden layers
        hidden_layer_list = [nn.Linear(10, layer_size), nn.ReLU()]
        for _ in range(num_layers-1):
            hidden_layer_list.append(nn.Linear(layer_size, layer_size))
            hidden_layer_list.append(nn.ReLU())
        
        self.hidden = nn.Sequential(*hidden_layer_list)
            
        # Output layer
        self.out = nn.Sequential(
            nn.Linear(layer_size, 1),
            nn.Sigmoid(),
        )
    
    def forward(self, x1, x2):
        """
        Args:
            x1: tensor [w_i^(l)]_{i \in T1}
            x2: tensor [w_i^(l)]_{i \in T2}

        Returns:
            tensor \hat{y}_{T1,T2}
        """
        x = torch.cat((x1, x2), dim=1) # Concatenate into input 
        x = self.ReLU(x) # Apply first ReLU
        x = self.hidden(x) # Hidden layers
        x = self.out(x) 
        
        return x

# Initialising modules and hyperparameters

# Hyperparameters

L = 10 # Number of times to iterate R and P module
rp_layer_size = 7*10 # Hidden layer size for R and P modules
rp_num_layers = 4 # Number of hidden layers for R and P modules
g_layer_size = 9*10 # Hidden layer size for G module
g_num_layers = 4 # Number of hidden layers for G module

batch_size = 32 # Batch size
reg_const = 1e-2 # Regularisation constant

# Modules

reward = RP(rp_layer_size, rp_num_layers).to(device)
penalty = RP(rp_layer_size, rp_num_layers).to(device)
g_mod = G(g_layer_size, g_num_layers).to(device)
W = torch.randn(num_champs).to(device) # num_champs (defined elsewhere in my code) sets the length of W
W.requires_grad = False

print(reward)
print(g_mod)

# Mini-batches

batch_inputs, batch_labels = torch.split(Xtr, batch_size), torch.split(Ytr, batch_size)

# Training

epochs = 50
learning_rate = 0.01

tr_lossg = []
te_lossg = []
te_predg = []

# Initialise
loss_f = torch.nn.BCELoss()
optimiser_R = torch.optim.Adam(reward.parameters(), lr=learning_rate, weight_decay=reg_const)
optimiser_P = torch.optim.Adam(penalty.parameters(), lr=learning_rate, weight_decay=reg_const)
optimiser_G = torch.optim.Adam(g_mod.parameters(), lr=learning_rate, weight_decay=reg_const)

@torch.no_grad()
def te_loss(X_1, X_2, Y):
    pred = g_mod(W[X_1], W[X_2])
    return loss_f(pred, Y).item()


# Train loop
for i in range(epochs):
    for j in range(len(batch_inputs)):

        # Load batch
        X_b = batch_inputs[j]
        X1_b = X_b[:, 0:5]
        X2_b = X_b[:, 5:10]
        Y_b = batch_labels[j]
        
        # Forward pass
        
        # Step (b) NEEDS TO BE TESTED
        for k in range(L):

            R, P = reward(W[X1_b], W[X2_b]), penalty(W[X1_b], W[X2_b])
            
            # Calculate updates (THIS IS THE 'DIFFERENTIABLE CALCULATION')
            temp1 = (Y_b * R[0]) - ((1 - Y_b) * P[0])
            temp2 = ((1 - Y_b) * R[1]) - (Y_b * P[1])
            sums_for_matches = torch.cat((temp1, temp2), dim=1)

            # Update W (loop over the matches in the batch)
            for l in range(X_b.shape[0]): # the last batch may be smaller than batch_size
                W[X_b[l]] += sums_for_matches[l]

            # DIFFERENTIABLE CALCULATION ENDS
                
            # Normalise
            W = F.normalize(W, dim=0)
            
        # Step (c)
        pred = g_mod(W[X1_b], W[X2_b])

        # Backward pass
        
        optimiser_R.zero_grad()
        optimiser_P.zero_grad()
        optimiser_G.zero_grad()

        loss = loss_f(pred, Y_b)    
        loss.backward()
        
        optimiser_R.step()
        optimiser_P.step()
        optimiser_G.step()

        # Tracking data
        tr_lossg.append(loss.item())
        te_lossg.append(te_loss(Xte[:, 0:5], Xte[:, 5:10], Yte)) 
        
    # Reporting status per epoch
    print(f'{i+1}/{epochs} complete ({round(((i+1)/epochs)*100)}%).',
          f'Train loss: {round(tr_lossg[-1], 4)}. Test loss: {round(te_lossg[-1], 4)}.', end='\r')

Hi,
As for this code, I cannot see a second .backward() call, so I cannot reproduce the error you mentioned. (That error occurs when multiple backward calls are made on the same graph.)
Is there any other error you need help with?

I tried simulating your code on my end, and it works just fine.

If there’s any error in the code you posted that you need help with, please provide the definitions and dimensions of num_champs, Xtr, etc. so that I can reproduce it on my end.

Best,
S

The second .backward() call happens after the first iteration (i.e. when .backward() is called for the second batch of the first epoch). I have actually managed to fix the bug by including W = W.detach() before each batch. I’m guessing that the issue was that each batch’s gradient history was linked through W. The network is still not working correctly, but I think that’s down to a design issue at the moment. Thanks for your help.
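For anyone who runs into the same thing, here is a minimal sketch of the failing pattern and the fix, with a single parameter standing in for the networks and a toy tensor standing in for W:

import torch

p = torch.randn(3, requires_grad=True)  # stands in for a network parameter
W = torch.ones(3)                       # state carried across batches

for batch in range(2):
    W = W.detach()   # the fix: cut W loose from the previous batch's graph
    W = W * p        # differentiable update (mul saves its inputs for backward)
    loss = W.sum()
    loss.backward()  # succeeds every batch; without the detach() above, the
                     # second backward() reproduces the error quoted earlier
    p.grad = None    # clear the gradient between batches, as an optimiser would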

Louis

I see.

Just in case if this helps:
This error occurs when you try to backpropagate through a graph a second time (after the first .backward() call).
The reason is that PyTorch frees memory aggressively: as soon as .backward() is called, all the references to the saved tensors (the intermediate values required for gradient computation) are freed.

Since the references to the saved tensors are freed, the gradient calculation cannot be performed again, and so a second call raises the error.

Specifying retain_graph=True causes those saved tensors to be retained, but this comes with its own caveats (higher memory use, and needing it is often a sign that something should have been detached). Check out some good threads on the forums on when and how retain_graph=True should be used.
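A minimal example of both behaviours:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * x                      # mul saves x for the backward pass

y.backward(retain_graph=True)  # the saved tensors are kept alive
y.backward()                   # second backward works; x.grad accumulates to 8.0

z = x * x
z.backward()                   # the graph of z is freed here
# z.backward()                 # uncommenting this raises the error quoted above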

detach() returns a new tensor that shares the same data but has requires_grad set to False and is cut off from the computation graph. This is probably why it solved the error: with the detach, one batch’s graph no longer links into the next.
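A small illustration:

import torch

a = torch.tensor(3.0, requires_grad=True)
b = a * 2        # on the graph: b.requires_grad is True, b.grad_fn is MulBackward0
c = b.detach()   # same value, but requires_grad is False and grad_fn is None

loss = b + c     # c contributes its value but no gradient path back to a
loss.backward()
print(a.grad)    # tensor(2.) - only the b branch backpropagates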