AssertionError when calling pytorch geometrics GNNexplainer

I want to use pytorch geometrics GNNExplainer on my trained GNN. I defined the forward method in my GNN as

    def forward(self, 
                x, 
                edge_index, 
                cell_batch, 
                drug_x, 
                drug_edge_index,
                drug_batch):
        cell_emb = self.cell_emb(x, 
                                 edge_index, 
                                 cell_batch)
        drug_emb = self.drug_emb(drug_x, 
                                 drug_edge_index, 
                                 drug_batch)        
        concat = torch.cat([cell_emb, drug_emb], -1)
        y_pred = self.fcn(concat)
        y_pred = y_pred.reshape(y_pred.shape[0])
        return y_pred 

and I define the explainer with

import torch
from torch_geometric.explain import Explainer, GNNExplainer

device = torch.device('cpu')
model.model.to(device)

explainer = Explainer(
    model=model.model,
    algorithm=GNNExplainer(epochs=200,
                           return_type='regression',
                           nhops=2,
                           return_mask=True),
    explainer_config=dict(
        explanation_type='model',
        node_mask_type='attributes',
        edge_mask_type='object'
    ),
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    )
)

Now I tried to run the explainer in the following way

node_index = 0
explanation = explainer(cl.x.float(), 
                        cl.edge_index,      
                        cell_batch=cl.batch,
                        drug_x=dr.x.float(),
                        drug_edge_index=dr.edge_index,
                        drug_batch=dr.batch,
                        index=node_index)

but this throws me the following Assertion error

Cell In[12], line 249, in GraphGraph.forward(self, x, edge_index, cell_batch, drug_x, drug_edge_index, drug_batch)
    239 def forward(self, 
    240             x, 
    241             edge_index, 
   (...)
    244             drug_edge_index,
    245             drug_batch):
    246     cell_emb = self.cell_emb(x, 
    247                              edge_index, 
    248                              cell_batch)
--> 249     drug_emb = self.drug_emb(drug_x, 
    250                              drug_edge_index, 
    251                              drug_batch)        
    252     concat = torch.cat([cell_emb, drug_emb], -1)
    253     y_pred = self.fcn(concat)

File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File /tmp/ec2-user_pyg/tmpx2oh3fmp.py:18, in Sequential_819ab2.forward(self, x, edge_index, batch)
     16 def forward(self, x, edge_index, batch):
     17     """"""
---> 18     x1 = self.module_0(x, edge_index)
     19     x1 = self.module_1(x1)
     20     x1 = self.module_2(x1)

File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch_geometric/nn/conv/gin_conv.py:74, in GINConv.forward(self, x, edge_index, size)
     71     x: OptPairTensor = (x, x)
     73 # propagate_type: (x: OptPairTensor)
---> 74 out = self.propagate(edge_index, x=x, size=size)
     76 x_r = x[1]
     77 if x_r is not None:

File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py:446, in MessagePassing.propagate(self, edge_index, size, **kwargs)
    443 if self.explain:
    444     explain_msg_kwargs = self.inspector.distribute(
    445         'explain_message', coll_dict)
--> 446     out = self.explain_message(out, **explain_msg_kwargs)
    448 aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
    449 for hook in self._aggregate_forward_pre_hooks.values():

File ~/anaconda3/envs/env280123/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py:559, in MessagePassing.explain_message(self, inputs, size_i)
    557     loop = edge_mask.new_ones(size_i)
    558     edge_mask = torch.cat([edge_mask, loop], dim=0)
--> 559 assert inputs.size(self.node_dim) == edge_mask.size(0)
    561 size = [1] * inputs.dim()
    562 size[self.node_dim] = -1

AssertionError: 

I think the important part of this message is in here

Cell In[12], line 249, in GraphGraph.forward(self, x, edge_index, cell_batch, drug_x, drug_edge_index, drug_batch)
    239 def forward(self, 
    240             x, 
    241             edge_index, 
   (...)
    244             drug_edge_index,
    245             drug_batch):
    246     cell_emb = self.cell_emb(x, 
    247                              edge_index, 
    248                              cell_batch)
--> 249     drug_emb = self.drug_emb(drug_x, 
    250                              drug_edge_index, 
    251                              drug_batch)        
    252     concat = torch.cat([cell_emb, drug_emb], -1)
    253     y_pred = self.fcn(concat)

where it tells me that I specified the input for the drug embedding wrongly.

I am pretty new to all this. I am pretty sure this is an easy solution for someone with experience. Can someone help me on how I need to specify the input to the explainer?