How to combine two PyTorch networks into one - specifically just adding a sequential head onto the end of a conv network?

Please let me know if this is not an appropriate question.

I am following this example for building a PyTorch model (specifically the end of the colab - the graph-level task section).

So I have this setup and these two classes:

device = torch.device('cuda')
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
train_dataset = dataset #just for testing
val_dataset = dataset
test_dataset = dataset
graph_train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True) 
graph_val_loader = DataLoader(val_dataset, batch_size=8) 
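(For completeness: gnn_layer_by_name, used inside GNNModel below, comes from the tutorial; if I remember the notebook correctly it is just a dict mapping a layer-name string to the corresponding PyTorch Geometric layer class, roughly:)

import torch_geometric.nn as geom_nn  # already imported as geom_nn elsewhere in my script

gnn_layer_by_name = {
    "GCN": geom_nn.GCNConv,
    "GAT": geom_nn.GATConv,
    "GraphConv": geom_nn.GraphConv
}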



class GNNModel(nn.Module):
  
    def __init__(self, c_in, c_hidden, c_out, num_layers, activation_function, optimizer_name, learning_rate, dp_rate_linear,layer_name="GCN", **kwargs):
        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]
        
        layers = []
        activation_function = eval(activation_function) ##not great to use
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            layers += [
                gnn_layer(in_channels=in_channels, 
                          out_channels=out_channels,
                          **kwargs),
                activation_function,
                nn.Dropout(p=dp_rate_linear)
            ]
            in_channels = c_hidden
        layers += [gnn_layer(in_channels=in_channels, 
                             out_channels=c_out,
                             **kwargs)]
        self.layers = nn.ModuleList(layers)
    
    def forward(self, x, edge_index):
        for l in self.layers:
            if isinstance(l, geom_nn.MessagePassing):
                x = l(x, edge_index)
            else:
                x = l(x)
        return x


class GraphGNNModel(nn.Module):
    
    def __init__(self, c_in, c_hidden, c_out, dp_rate_linear,**kwargs):
        super().__init__()
        self.GNN = GNNModel(c_in=c_in, 
                            c_hidden=c_hidden, 
                            c_out=c_hidden,
                            dp_rate_linear = dp_rate_linear, 
                            **kwargs)
        self.head = nn.Sequential(
            nn.Dropout(p=dp_rate_linear),
            nn.Linear(c_hidden, c_out)
        )
    def forward(self, x, edge_index, batch_idx):
        x = self.GNN(x, edge_index)
        x = geom_nn.global_mean_pool(x, batch_idx) 
        x = self.head(x)
        return x
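(For context on the batch_idx argument: it's the batch assignment vector that PyTorch Geometric's DataLoader attaches to each mini-batch, mapping every node to the graph it belongs to; global_mean_pool uses it to average the node embeddings of each graph separately. A minimal illustration, just to show the shapes involved:)

batch = next(iter(graph_train_loader))
print(batch.batch)   # tensor like [0, 0, ..., 1, 1, ..., 7, 7] - one entry per node

# pooling [num_nodes_in_batch, 16] node features with the batch vector gives
# one 16-dim vector per graph, i.e. shape [8, 16] for a full batch of 8 graphs
pooled = geom_nn.global_mean_pool(torch.randn(batch.num_nodes, 16), batch.batch)
print(pooled.shape)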

As you can see, I don't really need GNNModel and GraphGNNModel to be two separate classes; the second class just adds a sequential head to the end of the first.

I tried combining the two classes by doing:

class GNNModel(nn.Module):
    
    def __init__(self, c_in, c_hidden, c_out, num_layers, activation_function, optimizer_name, learning_rate, dp_rate_linear,layer_name="GCN" ,**kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Number of "hidden" graph layers
            layer_name - String of the graph layer to use
            dp_rate_linear - Dropout rate to apply throughout the network
            kwargs - Additional arguments for the graph layer (e.g. number of heads for GAT; I'm not using GAT here)
            activation_function - Activation function, passed as a string and eval'd below
        """

        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]
        
        layers = []

        activation_function = eval(activation_function) ##not great to use
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            layers += [
                gnn_layer(in_channels=in_channels, 
                          out_channels=out_channels,
                          **kwargs),
                activation_function,
                nn.Dropout(p=dp_rate_linear)
            ]
            in_channels = c_hidden
        layers += [gnn_layer(in_channels=in_channels, 
                             out_channels=c_out,
                             **kwargs)]
        self.layers = nn.ModuleList(layers)
    

        self.head = nn.Sequential(
            nn.Dropout(p=dp_rate_linear),
            nn.Linear(c_hidden, c_out)
        )

    def forward(self, x, edge_index):
        for l in self.layers:
            if isinstance(l, geom_nn.MessagePassing): #passing data between conv
                x = l(x, edge_index) #what is this
            else:
                x = l(x)

        x = self.GNN(x, edge_index)
        x = geom_nn.global_mean_pool(x, batch_idx) 
        x = self.head(x)
        return x

And then I wrap the model in a LightningModule:

class GraphLevelGNN(pl.LightningModule):
    """
    Aim: To combine all the bits of the network together needed for training.
    See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html; below is a standard set-up for using PyTorch Lightning modules.
    """
    def __init__(self,**model_kwargs):
        super().__init__()


       # Saving hyperparameters
        self.save_hyperparameters()
        self.model = GNNModel(**model_kwargs) 
        print(self)
        self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()
        self.optimizer_name = model_kwargs['optimizer_name']
        self.learning_rate = model_kwargs['learning_rate']


    def forward(self, data, mode="train"):
        print('**')
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        x = self.model(x, edge_index, batch_idx)
        x = x.squeeze(dim=-1)


        if self.hparams.c_out == 1:
            preds = (x > 0).float() # assign class 0 or 1 depending on whether the logit is above or below 0
            data.y = data.y.float()

        else:
            preds = x.argmax(dim=-1)

        loss = self.loss_module(x, data.y.float())
        acc = (preds == data.y).sum().float() / preds.shape[0]

        data.y = data.y.int()
        preds = preds.int()

        f1 = BinaryF1Score().to(device) 
        f1_score = f1(preds,data.y).to(device)

        precision = BinaryPrecision().to(device)
        precision_score=precision(preds,data.y).to(device)

        recall = BinaryRecall().to(device)
        recall_score=recall(preds,data.y).to(device)

        return loss, acc, f1_score,precision_score, recall_score,preds


    def configure_optimizers(self):
        learning_rate = self.learning_rate

        if self.optimizer_name == 'SGD':
            optimizer = optim.SGD(self.parameters(),lr=learning_rate)

        elif self.optimizer_name == 'NAdam':
            optimizer = optim.NAdam(self.parameters(), lr=learning_rate)   
        
        elif self.optimizer_name == 'Adam':
            optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        elif self.optimizer_name == 'ASGD':
            optimizer = optim.ASGD(self.parameters(), lr=learning_rate)

        elif self.optimizer_name == 'RMSProp':
            optimizer = optim.RMSprop(self.parameters(), lr=learning_rate)

        elif self.optimizer_name == 'LBFGS':
            optimizer = optim.LBFGS(self.parameters(), lr=learning_rate)

        elif self.optimizer_name == 'AdamW':
            optimizer = optim.AdamW(self.parameters(), lr=learning_rate)

        elif self.optimizer_name == 'Adadelta':
            optimizer = optim.Adadelta(self.parameters(), lr=learning_rate)

        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer_name}")

        return optimizer


    def training_step(self, batch, batch_idx):
        loss, acc, _,_,_,_ = self.forward(batch, mode="train")
        self.log('train_loss', loss,on_epoch=True,logger=True)
        self.log('train_acc', acc,on_epoch=True,logger=True)
        return loss


    def validation_step(self, batch, batch_idx):
        loss, acc, _,_,_,_ = self.forward(batch, mode="val")
        self.log('val_acc', acc,on_epoch=True,logger=True)
        self.log('val_loss', loss,on_epoch=True,logger=True)
        


    def test_step(self, batch, batch_idx):
        loss,acc, f1,precision, recall,preds = self.forward(batch, mode="test")
        self.log('test_acc', acc,on_epoch=True,logger=True)
        self.log('test_f1', f1,on_epoch=True,logger=True)
        self.log('test_precision', precision,on_epoch=True,logger=True)
        self.log('test_recall', recall,on_epoch=True,logger=True)
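(For reference, a minimal sketch of how a module like this would typically be instantiated and trained; the hyperparameter values below are placeholders, not the ones I actually use:)

model = GraphLevelGNN(c_in=dataset.num_node_features,
                      c_hidden=64,
                      c_out=1,
                      num_layers=3,
                      activation_function='nn.ReLU()',
                      optimizer_name='Adam',
                      learning_rate=1e-3,
                      dp_rate_linear=0.5)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, graph_train_loader, graph_val_loader)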


But I get the error:

TypeError: forward() takes 3 positional arguments but 4 were given

Could someone show me how to combine the two classes? I don't think they need to be separate as in the example: the first class just builds the list of conv layers, and the second adds a sequential head at the end, so I don't see why they should be kept apart.

I did see this answer; my problem is that while I understand that answer, I struggle to apply it to my own code. Please let me know if there isn't enough information in the code here, thanks.

The issue is raised in:

x = self.model(x, edge_index, batch_idx)

since self.model is defined as:

self.model = GNNModel(**model_kwargs) 

which expects only two input arguments to its forward (3 if you count the self argument):

def forward(self, x, edge_index):

Could you explain why the batch_idx argument is used here and check whether removing it would work?

Thanks.

When I remove batch_idx from the call, i.e.

x = self.model(x, edge_index) # removed batch_idx

and keep everything else the same, I get:

  File "pytorch_checking_all_sections.py", line 477, in forward
    x = self.model(x, edge_index) #removed batch_idx
  File "/home/miniconda3/envs/pytorch37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "pytorch_checking_all_sections.py", line 447, in forward
    x = self.GNN(x, edge_index)
  File "/home/miniconda3/envs/pytorch37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1178, in __getattr__
    type(self).__name__, name))
AttributeError: 'GNNModel' object has no attribute 'GNN'

I can understand this error.

In the old version (where the two classes are separate), I had:

    self.GNN = GNNModel(c_in=c_in, 
                        c_hidden=c_hidden, 
                        c_out=c_hidden,
                        dp_rate_linear = dp_rate_linear, 
                        **kwargs)

Whereas now I have removed that definition, I still have:

    x = self.GNN(x, edge_index)

in the combined GNNModel(), so I understand I'm trying to call something that is no longer defined; I'm just not sure what the alternative should be, i.e. how do I say "use the layers I've defined in the __init__ of GNNModel()", which is what this line is meant to do, I guess.

Thanks.

You could use the layers directly; the problem is that your merged GNNModel.__init__ no longer defines self.GNN, which is why the AttributeError is raised. In the combined class the message-passing stack already lives in self.layers, and the loop at the top of forward has already run it, so the x = self.GNN(x, edge_index) line can simply be dropped.
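A minimal sketch of what the merged class's forward could look like (keeping your combined __init__, just dropping the self.GNN call and accepting the batch index again; untested against your exact setup):

class GNNModel(nn.Module):
    # __init__ exactly as in your combined version above
    # (self.layers plus self.head), with one change noted below

    def forward(self, x, edge_index, batch_idx):
        # message-passing / activation / dropout stack
        for l in self.layers:
            if isinstance(l, geom_nn.MessagePassing):
                x = l(x, edge_index)
            else:
                x = l(x)
        # pool node embeddings into one embedding per graph, then classify
        x = geom_nn.global_mean_pool(x, batch_idx)
        x = self.head(x)
        return x

Since forward takes batch_idx again, the original call x = self.model(x, edge_index, batch_idx) in GraphLevelGNN.forward should work unchanged. One more thing to watch: in the two-class version the GNN was built with c_out=c_hidden, so the last graph layer emitted c_hidden features and the head mapped c_hidden to c_out. In your merged __init__ the final gnn_layer outputs c_out features while self.head still expects c_hidden inputs, so keep out_channels=c_hidden for that final graph layer and let the head do the c_hidden to c_out mapping.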