I know this question has been answered many times, but for some reason none of the tips I found worked. I am trying to use PGExplainer from PyTorch Geometric to explain graphs with an already-trained model. Here's the code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
class GCN(nn.Module):
    """Three-layer GCN with mean-pool readout and a linear classifier.

    Produces one row of raw class scores (logits) per graph.

    Args:
        input_dim: Number of input node features.
        hidden_channels: Width of every GCN layer.
        num_classes: Number of output classes; defaults to 3, the value the
            original definition hard-coded (so existing checkpoints still load).
    """

    def __init__(self, input_dim, hidden_channels, num_classes=3):
        # Must be ``__init__`` (not ``init``): otherwise this method is never
        # called as the constructor and the layers below are never created.
        super().__init__()
        # Seeds the global torch RNG so weight initialisation is reproducible.
        torch.manual_seed(42)
        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Return raw class scores, one row per graph.

        ``batch`` assigns each node to its graph. When it is ``None``, all
        nodes are pooled together as a single graph.
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        # Graph-level readout. The function is imported directly from
        # torch_geometric.nn; there is no `gnn` alias in scope.
        x = global_mean_pool(x, batch)
        return self.lin(x)
# Rebuild the model with the hyper-parameters it was trained with, then
# restore the trained weights. `device` is expected to be defined earlier in
# the script (it is not visible in this snippet).
model = GCN(input_dim=3, hidden_channels=256)
model.to(device)
model.load_state_dict(torch.load("model1.pth", map_location=device))
# The GNN is only used for inference from here on; only the explainer is trained.
model.eval()
from torch_geometric.explain import Explainer, PGExplainer

# Keep the GNN in float32. PGExplainer builds its own MLP in float32, so
# calling `model.double()` makes the node embeddings float64 and the explainer
# fails with a dtype mismatch inside `train`. Instead, the input features are
# cast to float32 in the training loop below.
model.float()

explainer = Explainer(
    model=model,
    # PGExplainer is an nn.Module with trainable parameters, so it must live
    # on the same device as the model and the data.
    algorithm=PGExplainer(epochs=10, lr=0.003).to(device),
    explanation_type="phenomenon",
    edge_mask_type="object",
    model_config=dict(
        mode="multiclass_classification",
        task_level="graph",
        return_type="raw",
    ),
    # Include only the top 10 most important edges:
    threshold_config=dict(threshold_type="topk", value=10),
)

# PGExplainer needs to be trained separately because it is a parametric
# explainer, i.e. it uses a neural network to generate explanations.
# The epoch count is taken from the algorithm so the loop and its temperature
# schedule cannot drift apart.
for epoch in range(explainer.algorithm.epochs):
    epoch_loss = 0.0
    for data in val_dataloader:
        # Moving the Batch moves x, edge_index, y and batch together.
        data = data.to(device)
        # Match the model's float32 weights (see `model.float()` above).
        x = data.x.float()
        # Multiclass losses need integer class indices.
        # NOTE(review): assumes data.y holds class indices, not one-hot
        # vectors. Confirm against the dataset.
        target = data.y.long()
        # Trains the explainer's MLP (not the GNN); `batch` is forwarded to
        # GCN.forward as a keyword argument.
        loss = explainer.algorithm.train(
            epoch,
            model,
            x,
            data.edge_index,
            target=target,
            batch=data.batch,
        )
        epoch_loss += loss
    print(f"Epoch {epoch:02d}, total loss: {epoch_loss:.4f}")
The error always occurs at the `loss = …` line.