I want to use PyTorch Geometric's GNNExplainer on my trained GNN. I defined the forward method of my GNN as
def forward(self,
            x,
            edge_index,
            cell_batch,
            drug_x,
            drug_edge_index,
            drug_batch):
    """Embed the cell graph and the drug graph, then predict from both.

    Args:
        x: Node features handed to ``self.cell_emb``.
        edge_index: Edge index handed to ``self.cell_emb``.
        cell_batch: Third argument of ``self.cell_emb`` (presumably the
            node-to-graph assignment vector -- confirm against the encoder).
        drug_x: Node features handed to ``self.drug_emb``.
        drug_edge_index: Edge index handed to ``self.drug_emb``.
        drug_batch: Third argument of ``self.drug_emb`` (presumably the
            node-to-graph assignment vector -- confirm against the encoder).

    Returns:
        The output of ``self.fcn`` reshaped to a 1-D tensor whose length is
        the size of its first dimension.
    """
    cell_emb = self.cell_emb(x, edge_index, cell_batch)
    drug_emb = self.drug_emb(drug_x, drug_edge_index, drug_batch)

    # Join the two embeddings along the last dimension before the head.
    concat = torch.cat([cell_emb, drug_emb], -1)
    y_pred = self.fcn(concat)

    # Flatten to 1-D; this raises if ``self.fcn`` yields more than one value
    # per row, because the element count would no longer match.
    y_pred = y_pred.reshape(y_pred.shape[0])
    return y_pred
and I define the explainer with
import torch
from torch_geometric.explain import Explainer, GNNExplainer
# Run the explanation on the CPU.
device = torch.device('cpu')
model.model.to(device)

explainer = Explainer(
    model=model.model,
    # `return_type`, `nhops` and `return_mask` are not constructor parameters
    # of torch_geometric.explain.GNNExplainer (it takes `epochs` and `lr`);
    # any other keyword is merged into its loss-coefficient table and is
    # otherwise unused, so they are dropped here. The regression / raw-output
    # settings are carried by `model_config` below.
    algorithm=GNNExplainer(epochs=200),
    # NOTE(review): `explainer_config` is the keyword expected by the
    # installed release shown in the traceback; newer releases may expect
    # these three settings as direct keyword arguments -- confirm on upgrade.
    explainer_config=dict(
        explanation_type='model',
        node_mask_type='attributes',
        edge_mask_type='object',
    ),
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    ),
)
I then tried to run the explainer as follows
# Passed to the explainer as `index`.
# NOTE(review): the explainer is configured with task_level='graph', so this
# presumably selects an output entry (a graph), not a node -- confirm.
node_index = 0

# The first two positional arguments are the cell graph's features and edges;
# everything else is forwarded by keyword to the model's forward().
explanation = explainer(
    cl.x.float(),
    cl.edge_index,
    cell_batch=cl.batch,
    drug_x=dr.x.float(),
    drug_edge_index=dr.edge_index,
    drug_batch=dr.batch,
    index=node_index,
)
but this raises the following AssertionError
Cell In[12], line 249, in GraphGraph.forward(self, x, edge_index, cell_batch, drug_x, drug_edge_index, drug_batch)
239 def forward(self,
240 x,
241 edge_index,
(...)
244 drug_edge_index,
245 drug_batch):
246 cell_emb = self.cell_emb(x,
247 edge_index,
248 cell_batch)
--> 249 drug_emb = self.drug_emb(drug_x,
250 drug_edge_index,
251 drug_batch)
252 concat = torch.cat([cell_emb, drug_emb], -1)
253 y_pred = self.fcn(concat)
File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File /tmp/ec2-user_pyg/tmpx2oh3fmp.py:18, in Sequential_819ab2.forward(self, x, edge_index, batch)
16 def forward(self, x, edge_index, batch):
17 """"""
---> 18 x1 = self.module_0(x, edge_index)
19 x1 = self.module_1(x1)
20 x1 = self.module_2(x1)
File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch_geometric/nn/conv/gin_conv.py:74, in GINConv.forward(self, x, edge_index, size)
71 x: OptPairTensor = (x, x)
73 # propagate_type: (x: OptPairTensor)
---> 74 out = self.propagate(edge_index, x=x, size=size)
76 x_r = x[1]
77 if x_r is not None:
File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py:446, in MessagePassing.propagate(self, edge_index, size, **kwargs)
443 if self.explain:
444 explain_msg_kwargs = self.inspector.distribute(
445 'explain_message', coll_dict)
--> 446 out = self.explain_message(out, **explain_msg_kwargs)
448 aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
449 for hook in self._aggregate_forward_pre_hooks.values():
File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py:559, in MessagePassing.explain_message(self, inputs, size_i)
557 loop = edge_mask.new_ones(size_i)
558 edge_mask = torch.cat([edge_mask, loop], dim=0)
--> 559 assert inputs.size(self.node_dim) == edge_mask.size(0)
561 size = [1] * inputs.dim()
562 size[self.node_dim] = -1
AssertionError:
I think the important part of this message is here
Cell In[12], line 249, in GraphGraph.forward(self, x, edge_index, cell_batch, drug_x, drug_edge_index, drug_batch)
239 def forward(self,
240 x,
241 edge_index,
(...)
244 drug_edge_index,
245 drug_batch):
246 cell_emb = self.cell_emb(x,
247 edge_index,
248 cell_batch)
--> 249 drug_emb = self.drug_emb(drug_x,
250 drug_edge_index,
251 drug_batch)
252 concat = torch.cat([cell_emb, drug_emb], -1)
253 y_pred = self.fcn(concat)
which seems to tell me that I specified the input for the drug embedding incorrectly.
I am pretty new to all this, and I am fairly sure this is an easy fix for someone with experience. Can someone help me with how I need to specify the inputs to the explainer?