I have a class implementing a bi-modal network (I deleted something below, just to show the general structure).
class MyGNN(torch.nn.Module):
def __init__(self):
super(MyGNN, self).__init__()
self.gnn_emb = Sequential('x, edge_index, batch',
[
(GATConv(in_channels=4, out_channels=256), 'x, edge_index -> x1'),
nn.ReLU(inplace=True),
(global_max_pool, 'x1, batch -> x2'),
nn.Linear(128, 128),
nn.ReLU()
]
)
self.ann_emb = nn.Sequential(
nn.Linear(256, 128),
nn.BatchNorm1d(128),
nn.ReLU()
)
self.fcn = nn.Sequential(
nn.Linear(2*128, 128),
nn.BatchNorm1d(128),
nn.ELU(),
nn.Dropout(p=0.1),
nn.Linear(64, 1)
)
def forward(self, cell, drug):
drug_emb = self.ann_emb(drug)
cell_emb = self.gnn_emb(cell.x.float(), cell.edge_index, cell.batch)
concat = torch.cat([cell_emb, drug_emb], -1)
y_pred = self.fcn(concat)
y_pred = y_pred.reshape(y_pred.shape[0])
return y_pred
The gnn_emb
part of the bi-modal model is using a GNN. The ann_emb
part is using a feed forward network.
I have now build the model and would like to run torch-geometric’s GNNExplainer on it. Specifically the method explain_graph
.
I initialized the GNNExplainer via
explainer = GNNExplainer(
model=my_gnn_model,
epochs=2,
return_type='regression',
nhops=2
)
And now for one sample (a specific cell-line and a specific drug) I am getting the below.
idx = 0
cl_drug_combi = my_gnn_model.val_loader.dataset[idx]
G = cl_drug_combi[0] # cell-line graph for cell-line at index idx
drug = cl_drug_combi[1] # drug at index idx
Now I am trying to run the explain_graph
method in the following way:
node_feat_mask, edge_mask = explainer.explain_graph(
G.x,
G.edge_index
)
However, this returns
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [146], in <cell line: 1>()
----> 1 node_feat_mask, edge_mask = explainer.explain_graph(
2 G.x,
3 G.edge_index,
4 )
File ~/PATH/lib/python3.10/site-packages/torch_geometric/nn/models/gnn_explainer.py:146, in GNNExplainer.explain_graph(self, x, edge_index, **kwargs)
143 batch = torch.zeros(x.shape[0], dtype=int, device=x.device)
145 # Get the initial prediction.
--> 146 prediction = self.get_initial_prediction(x, edge_index, batch=batch,
147 **kwargs)
149 self._initialize_masks(x, edge_index)
150 self.to(x.device)
File ~/PATH/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File ~/anaconda3/envs/Thesis/lib/python3.10/site-packages/torch_geometric/nn/models/explainer.py:235, in Explainer.get_initial_prediction(self, x, edge_index, batch, **kwargs)
231 @torch.no_grad()
232 def get_initial_prediction(self, x: Tensor, edge_index: Tensor,
233 batch: Optional[Tensor] = None, **kwargs):
234 if batch is not None:
--> 235 out = self.model(x, edge_index, batch, **kwargs)
236 else:
237 out = self.model(x, edge_index, **kwargs)
File ~/PATH/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
TypeError: MyGNN.forward() takes 3 positional arguments but 4 were given
So I guess the issue is that I need to provide the arguments in the forward
method as well right? Meaning cell
and drug
?
I tried that via
node_feat_mask, edge_mask = explainer.explain_graph(
G.x,
G.edge_index,
cell=G,
drug=drug
)
However, this throws also an error, namely
...
TypeError: MyGNN.forward() got multiple values for argument 'cell'
I really don’t understand why it says that it got multiple values?
Can anyone help me?