GNNExplainer.explain_graph() raises TypeError: MyGNN.forward() takes 3 positional arguments but 4 were given

I have a class implementing a bi-modal network (I removed some parts below to show only the general structure).

import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, Sequential, global_max_pool


class MyGNN(torch.nn.Module):
    def __init__(self):
        super(MyGNN, self).__init__()

        # Graph branch: embeds the cell-line graph.
        # Some layers were removed for brevity, so the dimensions shown here do not line up.
        self.gnn_emb = Sequential('x, edge_index, batch',
            [
                (GATConv(in_channels=4, out_channels=256), 'x, edge_index -> x1'),
                nn.ReLU(inplace=True),
                (global_max_pool, 'x1, batch -> x2'),
                nn.Linear(128, 128),
                nn.ReLU()
            ]
        )

        # Feed-forward branch: embeds the drug features.
        self.ann_emb = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )

        # Prediction head on the concatenated embeddings.
        self.fcn = nn.Sequential(
            nn.Linear(2*128, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 1)
        )

    def forward(self, cell, drug):
        drug_emb = self.ann_emb(drug)
        cell_emb = self.gnn_emb(cell.x.float(), cell.edge_index, cell.batch)
        concat = torch.cat([cell_emb, drug_emb], -1)
        y_pred = self.fcn(concat)
        y_pred = y_pred.reshape(y_pred.shape[0])
        return y_pred

The gnn_emb part of the bi-modal model uses a GNN; the ann_emb part uses a feed-forward network.

I have now built the model and would like to run torch-geometric’s GNNExplainer on it, specifically the explain_graph method.

I initialized the GNNExplainer via

explainer = GNNExplainer(
    model=my_gnn_model,
    epochs=2,
    return_type='regression',
    nhops=2
)

For one sample (a specific cell-line and a specific drug), I fetch the inputs as follows:

idx = 0
cl_drug_combi = my_gnn_model.val_loader.dataset[idx]
G = cl_drug_combi[0]       # cell-line graph for cell-line at index idx
drug = cl_drug_combi[1]    # drug at index idx  

Now I am trying to run the explain_graph method in the following way:

node_feat_mask, edge_mask = explainer.explain_graph(
    G.x,
    G.edge_index
)

However, this returns

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [146], in <cell line: 1>()
----> 1 node_feat_mask, edge_mask = explainer.explain_graph(
      2     G.x,
      3     G.edge_index,
      4 )

File ~/PATH/lib/python3.10/site-packages/torch_geometric/nn/models/gnn_explainer.py:146, in GNNExplainer.explain_graph(self, x, edge_index, **kwargs)
    143 batch = torch.zeros(x.shape[0], dtype=int, device=x.device)
    145 # Get the initial prediction.
--> 146 prediction = self.get_initial_prediction(x, edge_index, batch=batch,
    147                                          **kwargs)
    149 self._initialize_masks(x, edge_index)
    150 self.to(x.device)

File ~/PATH/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/anaconda3/envs/Thesis/lib/python3.10/site-packages/torch_geometric/nn/models/explainer.py:235, in Explainer.get_initial_prediction(self, x, edge_index, batch, **kwargs)
    231 @torch.no_grad()
    232 def get_initial_prediction(self, x: Tensor, edge_index: Tensor,
    233                            batch: Optional[Tensor] = None, **kwargs):
    234     if batch is not None:
--> 235         out = self.model(x, edge_index, batch, **kwargs)
    236     else:
    237         out = self.model(x, edge_index, **kwargs)

File ~/PATH/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

TypeError: MyGNN.forward() takes 3 positional arguments but 4 were given
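
To isolate the problem, here is a minimal sketch without PyG or torch that reproduces the same arity mismatch. It assumes, based only on the traceback above, that the explainer calls the model as self.model(x, edge_index, batch, **kwargs). FakeModel and call_like_explainer are hypothetical stand-ins, not part of any library.

```python
class FakeModel:
    """Hypothetical stand-in for MyGNN with the same forward signature."""

    def forward(self, cell, drug):
        return "ok"

    def __call__(self, *args, **kwargs):
        # torch.nn.Module.__call__ forwards its arguments in a similar way.
        return self.forward(*args, **kwargs)


def call_like_explainer(model, x, edge_index, batch, **kwargs):
    # Mirrors the call `self.model(x, edge_index, batch, **kwargs)` seen in the traceback.
    return model(x, edge_index, batch, **kwargs)


model = FakeModel()
try:
    call_like_explainer(model, "x", "edge_index", "batch")
    first_error = None
except TypeError as err:
    first_error = str(err)

# The message reports 4 positional arguments: self, x, edge_index and batch.
print(first_error)
```

So the three tensors x, edge_index and batch (plus self) arrive positionally, while forward only accepts self, cell and drug.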

So I guess the issue is that I also need to pass the arguments of the forward method, i.e. cell and drug, right?


I tried that via

node_feat_mask, edge_mask = explainer.explain_graph(
    G.x,
    G.edge_index,
    cell=G,
    drug=drug
)

However, this also raises an error, namely

...
TypeError: MyGNN.forward() got multiple values for argument 'cell'

I really don’t understand why it says that it got multiple values.
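
The second error can also be reproduced without PyG. The sketch below again assumes, from the traceback, that the explainer calls self.model(x, edge_index, batch, **kwargs): the positional x then already fills the cell slot, and cell=G supplies it a second time. FakeModel, call_like_explainer and ExplainerAdapter are hypothetical names, and SimpleNamespace only stands in for a graph object. The adapter at the end is one possible direction I have not tested against the real GNNExplainer; in practice it would presumably have to be a torch.nn.Module that builds a proper graph object.

```python
from types import SimpleNamespace


class FakeModel:
    """Hypothetical stand-in for MyGNN with the same forward signature."""

    def forward(self, cell, drug):
        return (cell.x, cell.edge_index, cell.batch, drug)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def call_like_explainer(model, x, edge_index, batch, **kwargs):
    # Mirrors the call `self.model(x, edge_index, batch, **kwargs)` seen in the traceback.
    return model(x, edge_index, batch, **kwargs)


class ExplainerAdapter:
    """Hypothetical adapter: accepts (x, edge_index, batch, drug) and
    repackages the first three into a single `cell` object."""

    def __init__(self, model):
        self.model = model

    def __call__(self, x, edge_index, batch, drug):
        cell = SimpleNamespace(x=x, edge_index=edge_index, batch=batch)
        return self.model(cell, drug)


model = FakeModel()
cell = SimpleNamespace(x="x", edge_index="edge_index", batch="batch")

# Same shape as my second attempt: positional tensors plus cell=... and drug=...
try:
    call_like_explainer(model, "x", "edge_index", "batch", cell=cell, drug="drug")
    second_error = None
except TypeError as err:
    second_error = str(err)
print(second_error)

# With the adapter, only `drug` is passed as an extra keyword argument.
adapted = call_like_explainer(ExplainerAdapter(model), "x", "edge_index", "batch", drug="drug")
print(adapted)
```

If this reading is right, the explainer would be constructed around such an adapter and explain_graph would be called with G.x, G.edge_index and drug=drug only.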

Can anyone help me?