I am trying to use batches with my heterogeneous model. I am not able to use global_mean_pool in the forward function since I get the error:
TraceError('symbolically traced variables cannot be used as inputs to control flow')
I tried using global_mean_pool as suggested here, but I get the error:
'NodeStorage' object has no attribute 'max'
Instead, I tried using MeanAggregation(), but now the model does not group the output by batch. The error is now:
ValueError: Expected input batch_size (1) to match target batch_size (2).
I reviewed this post, but x.size returns Proxy(getattr_2), so I can’t see what size it needs to be.
The code works if the batch size is 1. How can I make it work with a larger batch size?
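For reference, this is the per-graph pooling behavior I am after, shown on a plain homogeneous batch (a small standalone sketch, separate from my actual model):
import torch
from torch_geometric.nn import global_mean_pool
x = torch.randn(5, 8)                  # 5 nodes, 8 features each
batch = torch.tensor([0, 0, 0, 1, 1])  # first 3 nodes belong to graph 0, last 2 to graph 1
out = global_mean_pool(x, batch)       # shape [2, 8]: one pooled row per graph
print(out.shape)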
Here is a minimal reproducible example with some fake data:
import torch
import numpy as np
from torch.nn import Linear, ReLU
from torch.utils.data import Subset
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv, to_hetero, MeanAggregation
graphs = []
for i in range(10):
    data = HeteroData()
    data['node1'].x = torch.tensor([0, 1, 2]).reshape([-1, 1]).type(torch.FloatTensor)
    data['node2'].x = torch.tensor([0, 1, 2]).reshape([-1, 1]).type(torch.FloatTensor)
    data['node1', 'edge1', 'node2'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]]).type(torch.LongTensor)
    data['node1', 'edge1', 'node2'].edge_attr = torch.tensor([123, 101.5]).reshape([-1, 1])
    data['node2', 'edge2', 'node2'].edge_index = torch.tensor([[0, 1], [1, 2]]).type(torch.LongTensor)
    data['node2', 'edge2', 'node2'].edge_attr = torch.tensor([1.23, 0.34]).reshape([-1, 1])
    data.y = torch.tensor([1]).type(torch.LongTensor)
    graphs.append(data)
transform = T.ToUndirected()
dataset = []
for graph in graphs:
    graph = transform(graph)
    dataset.append(graph)
data = dataset[0] # Get the first graph object.
data.validate()
print()
print(data)
print('=============================================================')
# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
torch.manual_seed(12345)
N = len(dataset)
# generate & shuffle indices
indices = np.arange(N)
indices = np.random.permutation(indices)
# select train/test with split 75/25
train_indices = indices[:int(0.75 * N)]
test_indices = indices[int(0.75 * N):]
train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
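# For context, this is how I have been inspecting what the loader produces
# (a quick standalone check, not part of the model): each node type in the
# collated HeteroData batch carries its own `batch` assignment vector.
sample = next(iter(train_loader))
print(sample['node1'].batch)  # e.g. tensor([0, 0, 0, 1, 1, 1]) for batch_size=2
print(sample['node2'].batch)
print(sample.num_graphs)      # 2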
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.pool = MeanAggregation()
        self.lin = Linear(hidden_channels, 11)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        x = self.pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x
model = GNN(hidden_channels=64, out_channels=1)
model = to_hetero(model, data.metadata(), aggr='sum')
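# Quick check (not part of the model itself): data.metadata() returns the
# (node_types, edge_types) tuple that to_hetero uses to build one submodule
# per type; it should include 'node1', 'node2' and the reverse edge types
# added by ToUndirected.
print(data.metadata())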
def train():
    model.train()
    for batch in train_loader:  # Iterate in batches over the training dataset.
        out = model(batch.x_dict, batch.edge_index_dict, batch)  # Perform a single forward pass.
        loss = criterion(out, batch.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
def test(loader):
    model.eval()
    correct = 0
    for batch in loader:  # Iterate in batches over the training/test dataset.
        out = model(batch.x_dict, batch.edge_index_dict, batch)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == batch.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(1, 21):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Test: {test_acc:.4f}')
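# What I would expect for batch_size=2 (based on the homogeneous sketch above):
# out with shape [2, 11] (one pooled row per graph) and batch.y with shape [2].
# Instead the pooled output apparently collapses to a single row, which is
# what triggers the "batch_size (1) vs (2)" mismatch above.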